diff --git a/RELEASE.md b/RELEASE.md index 666a7dbc8bb..5e06d0d473e 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -58,6 +58,9 @@ * A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy. * A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models. +* `tf.distribute`: + * Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental. + ## Bug Fixes and Other Changes * diff --git a/tensorflow/python/distribute/central_storage_strategy.py b/tensorflow/python/distribute/central_storage_strategy.py index 6e2c441c468..e61570dd6bd 100644 --- a/tensorflow/python/distribute/central_storage_strategy.py +++ b/tensorflow/python/distribute/central_storage_strategy.py @@ -105,52 +105,6 @@ class CentralStorageStrategy(distribute_lib.Strategy): return super(CentralStorageStrategy, self).experimental_distribute_dataset( dataset, options) - def experimental_distribute_datasets_from_function(self, dataset_fn, # pylint: disable=useless-super-delegation - options=None): - """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. - - `dataset_fn` will be called once for each worker in the strategy. In this - case, we only have one worker so `dataset_fn` is called once. Each replica - on this worker will then dequeue a batch of elements from this local - dataset. - - The `dataset_fn` should take an `tf.distribute.InputContext` instance where - information about batching and input replication can be accessed. - - For Example: - ``` - def dataset_fn(input_context): - batch_size = input_context.get_per_replica_batch_size(global_batch_size) - d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) - return d.shard( - input_context.num_input_pipelines, input_context.input_pipeline_id) - - inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn) - - for batch in inputs: - replica_results = strategy.run(replica_fn, args=(batch,)) - ``` - - IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a - per-replica batch size, unlike `experimental_distribute_dataset`, which uses - the global batch size. This may be computed using - `input_context.get_per_replica_batch_size`. - - Args: - dataset_fn: A function taking a `tf.distribute.InputContext` instance and - returning a `tf.data.Dataset`. - options: `tf.distribute.InputOptions` used to control options on how this - dataset is distributed. - - Returns: - A "distributed `Dataset`", which the caller can iterate over like regular - datasets. - """ - return super( - CentralStorageStrategy, - self).experimental_distribute_datasets_from_function(dataset_fn, - options) - def experimental_local_results(self, value): # pylint: disable=useless-super-delegation """Returns the list of all local per-replica values contained in `value`. diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index e41fda63fd5..a3e63e8a6f1 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -479,8 +479,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): split_batch_by=self._num_replicas_in_sync, input_context=input_context) - def _experimental_distribute_datasets_from_function(self, dataset_fn, - options): + def _distribute_datasets_from_function(self, dataset_fn, options): input_context = self._make_input_context() return input_lib.get_distributed_datasets_from_function( dataset_fn=dataset_fn, diff --git a/tensorflow/python/distribute/custom_training_loop_input_test.py b/tensorflow/python/distribute/custom_training_loop_input_test.py index a835f5e5ac9..7debd850486 100644 --- a/tensorflow/python/distribute/custom_training_loop_input_test.py +++ b/tensorflow/python/distribute/custom_training_loop_input_test.py @@ -386,7 +386,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, def testDistributeDatasetIteratorWithoutFunction(self, distribution): data = [5., 6., 7., 8.] input_iterator = iter( - distribution.experimental_distribute_datasets_from_function( + distribution.distribute_datasets_from_function( lambda _: get_dataset_from_tensor_slices(data))) self.assertAllEqual( @@ -401,7 +401,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, def testDistributeDatasetIteratorWithFunction(self, distribution): data = [5., 6., 7., 8.] input_iterator = iter( - distribution.experimental_distribute_datasets_from_function( + distribution.distribute_datasets_from_function( lambda _: get_dataset_from_tensor_slices(data))) @def_function.function @@ -439,7 +439,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, def testDistributeDatasetFunctionPrefetch(self, distribution): data = [5., 6., 7., 8.] input_iterator = iter( - distribution.experimental_distribute_datasets_from_function( + distribution.distribute_datasets_from_function( lambda _: get_dataset_from_tensor_slices(data))) local_results = distribution.experimental_local_results( @@ -476,10 +476,9 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, def testDistributeDatasetFunctionHostPrefetch(self, distribution): data = [5., 6., 7., 8.] input_iterator = iter( - distribution.experimental_distribute_datasets_from_function( + distribution.distribute_datasets_from_function( lambda _: get_dataset_from_tensor_slices(data), - distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + distribute_lib.InputOptions(experimental_prefetch_to_device=False))) local_results = distribution.experimental_local_results( input_iterator.get_next()) @@ -645,7 +644,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, return dataset input_iterator = iter( - distribution.experimental_distribute_datasets_from_function(dataset_fn)) + distribution.distribute_datasets_from_function(dataset_fn)) @def_function.function def step_fn(example): @@ -673,7 +672,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, return dataset input_iterator = iter( - distribution.experimental_distribute_datasets_from_function(dataset_fn)) + distribution.distribute_datasets_from_function(dataset_fn)) @def_function.function def step_fn(example): @@ -724,7 +723,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, return dataset input_iterator = iter( - distribution.experimental_distribute_datasets_from_function(dataset_fn)) + distribution.distribute_datasets_from_function(dataset_fn)) @def_function.function def run(inputs): @@ -750,7 +749,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, return dataset input_iterator = iter( - distribution.experimental_distribute_datasets_from_function(dataset_fn)) + distribution.distribute_datasets_from_function(dataset_fn)) def embedding_lookup(inputs): embedding_weights = array_ops.zeros((1, 128)) @@ -935,7 +934,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, inputs = constant_op.constant([2., 3.]) dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5) input_iterator = iter( - distribution.experimental_distribute_datasets_from_function(dataset)) + distribution.distribute_datasets_from_function(dataset)) with distribution.scope(): var = variables.Variable(1.0) diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 944fc9c6ed7..14e2b6f3f02 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -114,7 +114,7 @@ the same way with eager and graph execution. devices, with a different value for each replica. They are produced by iterating through a distributed dataset returned by `tf.distribute.Strategy.experimental_distribute_dataset` and - `tf.distribute.Strategy.experimental_distribute_datasets_from_function`. They + `tf.distribute.Strategy.distribute_datasets_from_function`. They are also the typical result returned by `tf.distribute.Strategy.run`. @@ -688,7 +688,7 @@ class StrategyBase(object): a `tf.data.Dataset` to something that produces "per-replica" values. If you want to manually specify how the dataset should be partitioned across replicas, use - `tf.distribute.Strategy.experimental_distribute_datasets_from_function` + `tf.distribute.Strategy.distribute_datasets_from_function` instead. * Use `tf.distribute.Strategy.run` to run a function once per replica, taking values that may be "per-replica" (e.g. @@ -1030,13 +1030,13 @@ class StrategyBase(object): If the above batch splitting and dataset sharding logic is undesirable, please use - `tf.distribute.Strategy.experimental_distribute_datasets_from_function` + `tf.distribute.Strategy.distribute_datasets_from_function` instead, which does not do any automatic batching or sharding for you. Note: If you are using TPUStrategy, the order in which the data is processed by the workers when using `tf.distribute.Strategy.experimental_distribute_dataset` or - `tf.distribute.Strategy.experimental_distribute_datasets_from_function` is + `tf.distribute.Strategy.distribute_datasets_from_function` is not guaranteed. This is typically required if you are using `tf.distribute` to scale prediction. You can however insert an index for each element in the batch and order outputs accordingly. Refer to [this @@ -1045,7 +1045,7 @@ class StrategyBase(object): Note: Stateful dataset transformations are currently not supported with `tf.distribute.experimental_distribute_dataset` or - `tf.distribute.experimental_distribute_datasets_from_function`. Any stateful + `tf.distribute.distribute_datasets_from_function`. Any stateful ops that the dataset may have are currently ignored. For example, if your dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, then you have a dataset graph that depends on state (i.e the random seed) on @@ -1067,8 +1067,7 @@ class StrategyBase(object): # pylint: enable=line-too-long return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access - def experimental_distribute_datasets_from_function(self, dataset_fn, - options=None): + def distribute_datasets_from_function(self, dataset_fn, options=None): # pylint: disable=line-too-long """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. @@ -1077,7 +1076,7 @@ class StrategyBase(object): instance. It is expected that the returned dataset from `dataset_fn` is already batched by per-replica batch size (i.e. global batch size divided by the number of replicas in sync) and sharded. - `tf.distribute.Strategy.experimental_distribute_datasets_from_function` does + `tf.distribute.Strategy.distribute_datasets_from_function` does not batch or shard the `tf.data.Dataset` instance returned from the input function. `dataset_fn` will be called on the CPU device of each of the workers and each generates a dataset where every @@ -1112,7 +1111,7 @@ class StrategyBase(object): Note: If you are using TPUStrategy, the order in which the data is processed by the workers when using `tf.distribute.Strategy.experimental_distribute_dataset` or - `tf.distribute.Strategy.experimental_distribute_datasets_from_function` is + `tf.distribute.Strategy.distribute_datasets_from_function` is not guaranteed. This is typically required if you are using `tf.distribute` to scale prediction. You can however insert an index for each element in the batch and order outputs accordingly. Refer to [this @@ -1121,14 +1120,14 @@ class StrategyBase(object): Note: Stateful dataset transformations are currently not supported with `tf.distribute.experimental_distribute_dataset` or - `tf.distribute.experimental_distribute_datasets_from_function`. Any stateful + `tf.distribute.distribute_datasets_from_function`. Any stateful ops that the dataset may have are currently ignored. For example, if your dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, then you have a dataset graph that depends on state (i.e the random seed) on the local machine where the python process is being executed. For a tutorial on more usage and properties of this method, refer to the - [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function). + [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function)). If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). Args: @@ -1141,9 +1140,17 @@ class StrategyBase(object): A `tf.distribute.DistributedDataset`. """ # pylint: enable=line-too-long - return self._extended._experimental_distribute_datasets_from_function( # pylint: disable=protected-access + return self._extended._distribute_datasets_from_function( # pylint: disable=protected-access dataset_fn, options) + # TODO(b/162776748): Remove deprecated symbol. + @doc_controls.do_not_doc_inheritable + @deprecation.deprecated(None, "rename to distribute_datasets_from_function") + def experimental_distribute_datasets_from_function(self, + dataset_fn, + options=None): + return self.distribute_datasets_from_function(dataset_fn, options) + def run(self, fn, args=(), kwargs=None, options=None): """Invokes `fn` on each replica, with the given arguments. @@ -1152,7 +1159,7 @@ class StrategyBase(object): have `tf.distribute.DistributedValues`, such as those produced by a `tf.distribute.DistributedDataset` from `tf.distribute.Strategy.experimental_distribute_dataset` or - `tf.distribute.Strategy.experimental_distribute_datasets_from_function`, + `tf.distribute.Strategy.distribute_datasets_from_function`, when `fn` is executed on a particular replica, it will be executed with the component of `tf.distribute.DistributedValues` that correspond to that replica. @@ -2196,8 +2203,7 @@ class StrategyExtendedV2(object): def _experimental_distribute_dataset(self, dataset, options): raise NotImplementedError("must be implemented in descendants") - def _experimental_distribute_datasets_from_function(self, dataset_fn, - options): + def _distribute_datasets_from_function(self, dataset_fn, options): raise NotImplementedError("must be implemented in descendants") def _experimental_distribute_values_from_function(self, value_fn): @@ -3298,8 +3304,7 @@ class _DefaultDistributionExtended(StrategyExtendedV1): def _experimental_distribute_dataset(self, dataset, options): return dataset - def _experimental_distribute_datasets_from_function(self, dataset_fn, - options): + def _distribute_datasets_from_function(self, dataset_fn, options): return dataset_fn(InputContext()) def _experimental_distribute_values_from_function(self, value_fn): diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py index 816ff0ce465..0fe05b52d6f 100644 --- a/tensorflow/python/distribute/distribute_lib_test.py +++ b/tensorflow/python/distribute/distribute_lib_test.py @@ -89,8 +89,7 @@ class _TestExtended(distribute_lib.StrategyExtendedV1): [distribute_lib.InputContext()], self._container_strategy()) - def _experimental_distribute_datasets_from_function(self, dataset_fn, - options): + def _distribute_datasets_from_function(self, dataset_fn, options): return dataset_fn(distribute_lib.InputContext()) def _local_results(self, value): @@ -546,14 +545,14 @@ class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase): if context.executing_eagerly(): dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2) dist_dataset_from_func = \ - default_strategy.experimental_distribute_datasets_from_function( + default_strategy.distribute_datasets_from_function( dataset_fn) next_val = next(iter(dist_dataset_from_func)) self.assertAllEqual([0, 1], self.evaluate(next_val)) else: dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2) dist_dataset_from_func = \ - default_strategy.experimental_distribute_datasets_from_function( + default_strategy.distribute_datasets_from_function( dataset_fn) dataset_ops.make_initializable_iterator(dist_dataset_from_func) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index d689346870e..dbefff9cf66 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -274,7 +274,7 @@ class DistributedDatasetInterface(collections_abc.Iterable, There are two APIs to create a `tf.distribute.DistributedDataset` object: `tf.distribute.Strategy.experimental_distribute_dataset(dataset)`and - `tf.distribute.Strategy.experimental_distribute_datasets_from_function(dataset_fn)`. + `tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn)`. *When to use which?* When you have a `tf.data.Dataset` instance, and the regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance with a new batch size that is equal to the global batch size divided by the diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index b43dc961d27..ec0b591d710 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -120,7 +120,7 @@ class DistributedIteratorTestBase(test.TestCase): split_batch_by=split_batch_by, input_context=input_context) else: - return strategy.experimental_distribute_datasets_from_function(dataset) + return strategy.distribute_datasets_from_function(dataset) def _assert_iterator_values(self, iterator, @@ -1158,8 +1158,7 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, ds = distribution.experimental_distribute_dataset( dataset_fn(distribute_lib.InputContext())) else: - ds = distribution.experimental_distribute_datasets_from_function( - dataset_fn) + ds = distribution.distribute_datasets_from_function(dataset_fn) iterator = iter(ds) self.assertEqual(iterator._enable_get_next_as_optional, @@ -1208,8 +1207,7 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, ds = distribution.experimental_distribute_dataset( dataset_fn(distribute_lib.InputContext())) else: - ds = distribution.experimental_distribute_datasets_from_function( - dataset_fn) + ds = distribution.distribute_datasets_from_function(dataset_fn) # Iterate through all the batches and sum them up. def sum_batch(per_replica_features): diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index 691b29202e1..bc6ac811bbb 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -176,8 +176,7 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): dataset_fn(distribute_lib.InputContext())) type_spec = ds.element_spec else: - ds = distribution.experimental_distribute_datasets_from_function( - dataset_fn) + ds = distribution.distribute_datasets_from_function(dataset_fn) iterator = iter(ds) _check_type_spec_structure(iterator) type_spec = iterator.element_spec diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 79a563680ea..523c71c4fb5 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -507,8 +507,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): return numpy_dataset.one_host_numpy_dataset( numpy_input, self._host_input_device, session) - def _experimental_distribute_datasets_from_function(self, dataset_fn, - options): + def _distribute_datasets_from_function(self, dataset_fn, options): input_contexts = [] input_workers = self._input_workers_with_options(options) num_workers = input_workers.num_workers diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py index 8f40a5f7991..c256c2df78f 100644 --- a/tensorflow/python/distribute/one_device_strategy.py +++ b/tensorflow/python/distribute/one_device_strategy.py @@ -109,8 +109,10 @@ class OneDeviceStrategy(distribute_lib.Strategy): return super(OneDeviceStrategy, self).experimental_distribute_dataset( dataset, options) - def experimental_distribute_datasets_from_function(self, dataset_fn, # pylint: disable=useless-super-delegation - options=None): + def distribute_datasets_from_function( + self, + dataset_fn, # pylint: disable=useless-super-delegation + options=None): """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. `dataset_fn` will be called once for each worker in the strategy. In this @@ -127,7 +129,7 @@ class OneDeviceStrategy(distribute_lib.Strategy): return d.shard( input_context.num_input_pipelines, input_context.input_pipeline_id) - inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn) + inputs = strategy.distribute_datasets_from_function(dataset_fn) for batch in inputs: replica_results = strategy.run(replica_fn, args=(batch,)) @@ -148,9 +150,8 @@ class OneDeviceStrategy(distribute_lib.Strategy): A "distributed `Dataset`", which the caller can iterate over like regular datasets. """ - return super( - OneDeviceStrategy, self).experimental_distribute_datasets_from_function( - dataset_fn, options) + return super(OneDeviceStrategy, + self).distribute_datasets_from_function(dataset_fn, options) def experimental_local_results(self, value): # pylint: disable=useless-super-delegation """Returns the list of all local per-replica values contained in `value`. @@ -316,8 +317,7 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1): self._input_workers_with_options(options), self._container_strategy()) - def _experimental_distribute_datasets_from_function(self, dataset_fn, - options): + def _distribute_datasets_from_function(self, dataset_fn, options): return input_lib.get_distributed_datasets_from_function( dataset_fn, self._input_workers_with_options(options), diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index 1d4c593d48b..b60ea74dd04 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -128,12 +128,10 @@ class ParameterServerStrategy(distribute_lib.Strategy): self).experimental_distribute_dataset(dataset=dataset, options=options) - def experimental_distribute_datasets_from_function(self, dataset_fn, - options=None): + def distribute_datasets_from_function(self, dataset_fn, options=None): self._raise_pss_error_if_eager() - super(ParameterServerStrategy, - self).experimental_distribute_datasets_from_function( - dataset_fn=dataset_fn, options=options) + super(ParameterServerStrategy, self).distribute_datasets_from_function( + dataset_fn=dataset_fn, options=options) def run(self, fn, args=(), kwargs=None, options=None): self._raise_pss_error_if_eager() @@ -387,8 +385,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): return numpy_dataset.one_host_numpy_dataset( numpy_input, self._input_host_device, session) - def _experimental_distribute_datasets_from_function(self, dataset_fn, - options): + def _distribute_datasets_from_function(self, dataset_fn, options): if self._cluster_spec: input_pipeline_id = multi_worker_util.id_in_cluster( self._cluster_spec, self._task_type, self._task_id) diff --git a/tensorflow/python/distribute/parameter_server_strategy_test.py b/tensorflow/python/distribute/parameter_server_strategy_test.py index a68183adbaa..1b4cd21c249 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_test.py +++ b/tensorflow/python/distribute/parameter_server_strategy_test.py @@ -759,10 +759,9 @@ class ParameterServerStrategyTest( strategy.experimental_distribute_dataset, dataset.batch(2)) - self.assertRaisesRegex( - NotImplementedError, 'ParameterServerStrategy*', - strategy.experimental_distribute_datasets_from_function, - lambda _: dataset) + self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*', + strategy.distribute_datasets_from_function, + lambda _: dataset) self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*', strategy.scope) diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py index 0c9fbe0d660..5eeeb11fa8f 100644 --- a/tensorflow/python/distribute/strategy_common_test.py +++ b/tensorflow/python/distribute/strategy_common_test.py @@ -605,7 +605,7 @@ class DistributedCollectiveAllReduceStrategyTest( expected_sum_on_workers = {'chief': 10, 'worker': 35} input_iterator = iter( - strategy.experimental_distribute_datasets_from_function(dataset_fn)) + strategy.distribute_datasets_from_function(dataset_fn)) @def_function.function def run(iterator): @@ -648,7 +648,7 @@ class DistributedCollectiveAllReduceStrategyTest( input_context.input_pipeline_id) input_iterator = iter( - strategy.experimental_distribute_datasets_from_function(dataset_fn)) + strategy.distribute_datasets_from_function(dataset_fn)) @def_function.function def run(input_iterator): diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py index 0bc4c6fca68..f660e3ab9f8 100644 --- a/tensorflow/python/distribute/strategy_test_lib.py +++ b/tensorflow/python/distribute/strategy_test_lib.py @@ -340,7 +340,7 @@ class DistributionTestBase(test.TestCase): self, strategy, input_fn, expected_values, ignore_order=False): assert_same = self.assertCountEqual if ignore_order else self.assertEqual - iterable = strategy.experimental_distribute_datasets_from_function(input_fn) + iterable = strategy.distribute_datasets_from_function(input_fn) if context.executing_eagerly(): iterator = iter(iterable) diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 0d5ab9d01cd..1a3d49a2032 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -125,7 +125,7 @@ class TPUStrategyV2(distribute_lib.Strategy): `strategy.run` is called inside a `tf.function` if eager behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu. - `experimental_distribute_datasets_from_function` and + `distribute_datasets_from_function` and `experimental_distribute_dataset` APIs can be used to distribute the dataset across the TPU workers when writing your own training loop. If you are using `fit` and `compile` methods available in `tf.keras.Model`, then Keras will @@ -144,7 +144,7 @@ class TPUStrategyV2(distribute_lib.Strategy): ... y = np.random.randint(2, size=(2, 1)) ... dataset = tf.data.Dataset.from_tensor_slices((x, y)) ... return dataset.repeat().batch(1, drop_remainder=True) - >>> dist_dataset = strategy.experimental_distribute_datasets_from_function( + >>> dist_dataset = strategy.distribute_datasets_from_function( ... dataset_fn) >>> iterator = iter(dist_dataset) @@ -190,7 +190,7 @@ class TPUStrategyV2(distribute_lib.Strategy): ... # Add operation will be executed on logical device 0. ... output = strategy.experimental_assign_to_logical_device(output, 0) ... return output - >>> dist_dataset = strategy.experimental_distribute_datasets_from_function( + >>> dist_dataset = strategy.distribute_datasets_from_function( ... dataset_fn) >>> iterator = iter(dist_dataset) >>> strategy.run(step_fn, args=(next(iterator),)) @@ -229,7 +229,7 @@ class TPUStrategyV2(distribute_lib.Strategy): `tf.distribute.DistributedValues`, such as those produced by a `tf.distribute.DistributedDataset` from `tf.distribute.Strategy.experimental_distribute_dataset` or - `tf.distribute.Strategy.experimental_distribute_datasets_from_function`, + `tf.distribute.Strategy.distribute_datasets_from_function`, when `fn` is executed on a particular replica, it will be executed with the component of `tf.distribute.DistributedValues` that correspond to that replica. @@ -811,8 +811,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): self._container_strategy(), split_batch_by=self._num_replicas_in_sync) - def _experimental_distribute_datasets_from_function(self, dataset_fn, - options): + def _distribute_datasets_from_function(self, dataset_fn, options): input_workers = self._get_input_workers(options) input_contexts = [] num_workers = input_workers.num_workers diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index c2aa68a0785..70f05a3bb9c 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -509,10 +509,9 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): return dataset.map(make_sparse) dataset = iter( - strategy.experimental_distribute_datasets_from_function( + strategy.distribute_datasets_from_function( dataset_fn, - distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + distribute_lib.InputOptions(experimental_prefetch_to_device=False))) sparse, result = sparse_lookup(dataset) @@ -560,10 +559,9 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): return dataset.map(make_sparse) dataset = iter( - strategy.experimental_distribute_datasets_from_function( + strategy.distribute_datasets_from_function( dataset_fn, - distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + distribute_lib.InputOptions(experimental_prefetch_to_device=False))) output = sparse_lookup(dataset) @@ -616,7 +614,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): return dataset.map(make_sparse) dataset = iter( - strategy.experimental_distribute_datasets_from_function( + strategy.distribute_datasets_from_function( dataset_fn, options=distribute_lib.InputOptions( experimental_prefetch_to_device=False))) @@ -730,7 +728,7 @@ class TPUStrategyDataPrefetchTest(test.TestCase): return dataset.batch(strategy.num_replicas_in_sync) with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"): - iter(strategy.experimental_distribute_datasets_from_function(dataset_fn)) + iter(strategy.distribute_datasets_from_function(dataset_fn)) def test_prefetch_to_device_ragged_dataset_fn(self): strategy = get_tpu_strategy() @@ -745,7 +743,7 @@ class TPUStrategyDataPrefetchTest(test.TestCase): return dataset.batch(strategy.num_replicas_in_sync) with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"): - iter(strategy.experimental_distribute_datasets_from_function(dataset_fn)) + iter(strategy.distribute_datasets_from_function(dataset_fn)) class TPUStrategyDistributionTest( diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 7825c732356..18485749b18 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -1732,10 +1732,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, ds = ds.batch(5).repeat() return ds - ds = distribution.experimental_distribute_datasets_from_function( - make_dataset) - val_ds = distribution.experimental_distribute_datasets_from_function( - make_dataset) + ds = distribution.distribute_datasets_from_function(make_dataset) + val_ds = distribution.distribute_datasets_from_function(make_dataset) model.fit( ds, diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py index ededc63e860..1417824e53c 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2.py +++ b/tensorflow/python/tpu/tpu_embedding_v2.py @@ -152,7 +152,7 @@ class TPUEmbedding(tracking.AutoTrackable): ```python distributed_dataset = ( - strategy.experimental_distribute_datasets_from_function( + strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( experimental_prefetch_to_device=False)) @@ -583,7 +583,7 @@ class TPUEmbedding(tracking.AutoTrackable): embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) distributed_dataset = ( - strategy.experimental_distribute_datasets_from_function( + strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( experimental_prefetch_to_device=False)) @@ -680,7 +680,7 @@ class TPUEmbedding(tracking.AutoTrackable): embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) distributed_dataset = ( - strategy.experimental_distribute_datasets_from_function( + strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( experimental_prefetch_to_device=False)) @@ -1125,7 +1125,7 @@ class TPUEmbedding(tracking.AutoTrackable): features matches the per core batch size. This will automatically happen if your input dataset is batched to the global batch size and you use `tf.distribute.TPUStrategy`'s `experimental_distribute_dataset` - or if you use `experimental_distribute_datasets_from_function` and batch + or if you use `distribute_datasets_from_function` and batch to the per core batch size computed by the context passed to your input function. @@ -1135,7 +1135,7 @@ class TPUEmbedding(tracking.AutoTrackable): embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) distributed_dataset = ( - strategy.experimental_distribute_datasets_from_function( + strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( experimental_prefetch_to_device=False)) diff --git a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py index 7a9a727d956..8960d907be7 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py @@ -427,7 +427,7 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') input_fn = self._create_dense_input_fn(strategy) - dist = strategy.experimental_distribute_datasets_from_function( + dist = strategy.distribute_datasets_from_function( input_fn, options=distribute_lib.InputOptions( experimental_prefetch_to_device=False)) diff --git a/tensorflow/python/tpu/tpu_embedding_v2_test.py b/tensorflow/python/tpu/tpu_embedding_v2_test.py index 5e081d6f9ef..4ad26ce5742 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_test.py @@ -556,7 +556,7 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') input_fn = self._create_dense_input_fn(strategy, include_weights=True) - dist = strategy.experimental_distribute_datasets_from_function( + dist = strategy.distribute_datasets_from_function( input_fn, options=distribute_lib.InputOptions( experimental_prefetch_to_device=False)) @@ -744,8 +744,7 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') input_fn = self._create_dense_input_fn(strategy) - sparse_iter = iter(strategy.experimental_distribute_datasets_from_function( - input_fn)) + sparse_iter = iter(strategy.distribute_datasets_from_function(input_fn)) @def_function.function def test_fn(): @@ -768,8 +767,7 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') input_fn = self._create_dense_input_fn(strategy) - sparse_iter = iter(strategy.experimental_distribute_datasets_from_function( - input_fn)) + sparse_iter = iter(strategy.distribute_datasets_from_function(input_fn)) @def_function.function def test_fn(): @@ -1127,7 +1125,8 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): def input_fn(ctx): del ctx return dataset_ops.DatasetV2.from_tensors(feature).repeat() - dist = strategy.experimental_distribute_datasets_from_function( + + dist = strategy.distribute_datasets_from_function( input_fn, options=distribute_lib.InputOptions( experimental_prefetch_to_device=False)) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt index 85dd7f5eaa6..009e612f94f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-one-device-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-one-device-strategy.pbtxt index 23e03ceab02..e6729e82004 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-one-device-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-one-device-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt index 7fbd9dded22..00964c7cbb0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt @@ -27,6 +27,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt index 2f7ba2db15c..b360ed1f628 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt index dac5652c7fd..598a9dbc15b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt index f63c16dec5a..0e78cb93209 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt index 53d5b756568..4cbcab923c6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt @@ -32,6 +32,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt index 5e17d5c8752..598c01c6da0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt index aa23fddab08..5b044576d14 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt index 611247e3ab9..a8ebaa87590 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt @@ -27,6 +27,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt index 505c77be2e2..3a94e8b9fa4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_assign_to_logical_device" argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt index 7fd7878c45c..290d834305a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt index 66e1fb34bb5..2245eb4b122 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt index 3d8265ee720..707d7281a5c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt index 8956d528b3a..a2adcea87e6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_distribute_dataset" argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "