Rename "experimental_distribute_datasets_from_function" to "distribute_datasets_from_function".
PiperOrigin-RevId: 333413940 Change-Id: I06f79b8b9bb1445e2b890e4c49138e7463e37a5e
This commit is contained in: parent faf6a15ff4, commit 406ca6ad78
@@ -58,6 +58,9 @@
 * A new module named `tf.experimental.numpy` is added. This is a NumPy-compatible API for writing TF programs. It provides a class `ndarray`, which mimics the `ndarray` class in NumPy and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided, and their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of which operations are supported and how they differ from NumPy.
 * A major refactoring of the internals of the Keras Functional API has been completed, which should improve the reliability, stability, and performance of constructing Functional models.
+* `tf.distribute`:
+  * Deprecated `experimental_distribute_datasets_from_function` and renamed it
+    to `distribute_datasets_from_function`, as it is no longer experimental.
 
 ## Bug Fixes and Other Changes
 
 * <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
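The rename is purely mechanical; the old name remains as a deprecated alias that forwards to the new one. A minimal migration sketch (assuming a `MirroredStrategy`; `make_dataset` and the batch size are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def make_dataset(input_context):
  # This API does no automatic rebatching, so batch to the per-replica size.
  batch_size = input_context.get_per_replica_batch_size(64)
  return tf.data.Dataset.range(1024).batch(batch_size)

# Before (still works for now, but emits a deprecation warning):
# dist_ds = strategy.experimental_distribute_datasets_from_function(make_dataset)

# After:
dist_ds = strategy.distribute_datasets_from_function(make_dataset)
```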
@@ -105,52 +105,6 @@ class CentralStorageStrategy(distribute_lib.Strategy):
     return super(CentralStorageStrategy, self).experimental_distribute_dataset(
         dataset, options)
 
-  def experimental_distribute_datasets_from_function(self, dataset_fn,  # pylint: disable=useless-super-delegation
-                                                     options=None):
-    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
-
-    `dataset_fn` will be called once for each worker in the strategy. In this
-    case, we only have one worker, so `dataset_fn` is called once. Each replica
-    on this worker will then dequeue a batch of elements from this local
-    dataset.
-
-    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
-    information about batching and input replication can be accessed.
-
-    For example:
-
-    ```
-    def dataset_fn(input_context):
-      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
-      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
-      return d.shard(
-          input_context.num_input_pipelines, input_context.input_pipeline_id)
-
-    inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn)
-
-    for batch in inputs:
-      replica_results = strategy.run(replica_fn, args=(batch,))
-    ```
-
-    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
-    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
-    the global batch size. This may be computed using
-    `input_context.get_per_replica_batch_size`.
-
-    Args:
-      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
-        returning a `tf.data.Dataset`.
-      options: `tf.distribute.InputOptions` used to control options on how this
-        dataset is distributed.
-
-    Returns:
-      A "distributed `Dataset`", which the caller can iterate over like regular
-      datasets.
-    """
-    return super(
-        CentralStorageStrategy,
-        self).experimental_distribute_datasets_from_function(dataset_fn,
-                                                             options)
-
   def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
     """Returns the list of all local per-replica values contained in `value`.
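The removed docstring's example lives on in the base class; a self-contained, runnable variant of the same pattern (a sketch assuming a `MirroredStrategy` so it runs on CPU/GPU; the global batch size of 8 and `replica_fn` are placeholders):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 8

def dataset_fn(input_context):
  # Per-replica batch size: the global batch divided by replicas in sync.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat(64).batch(batch_size)
  # Shard so each input pipeline sees a disjoint slice of the data.
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

@tf.function
def replica_fn(batch):
  return tf.reduce_sum(batch)

inputs = strategy.distribute_datasets_from_function(dataset_fn)
for batch in inputs:
  replica_results = strategy.run(replica_fn, args=(batch,))
```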
@@ -479,8 +479,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         split_batch_by=self._num_replicas_in_sync,
         input_context=input_context)
 
-  def _experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                      options):
+  def _distribute_datasets_from_function(self, dataset_fn, options):
     input_context = self._make_input_context()
     return input_lib.get_distributed_datasets_from_function(
         dataset_fn=dataset_fn,
@@ -386,7 +386,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
   def testDistributeDatasetIteratorWithoutFunction(self, distribution):
     data = [5., 6., 7., 8.]
     input_iterator = iter(
-        distribution.experimental_distribute_datasets_from_function(
+        distribution.distribute_datasets_from_function(
             lambda _: get_dataset_from_tensor_slices(data)))
 
     self.assertAllEqual(
@@ -401,7 +401,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
   def testDistributeDatasetIteratorWithFunction(self, distribution):
     data = [5., 6., 7., 8.]
     input_iterator = iter(
-        distribution.experimental_distribute_datasets_from_function(
+        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))
 
     @def_function.function
@@ -439,7 +439,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
   def testDistributeDatasetFunctionPrefetch(self, distribution):
     data = [5., 6., 7., 8.]
     input_iterator = iter(
-        distribution.experimental_distribute_datasets_from_function(
+        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))
 
     local_results = distribution.experimental_local_results(
@@ -476,10 +476,9 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
   def testDistributeDatasetFunctionHostPrefetch(self, distribution):
     data = [5., 6., 7., 8.]
     input_iterator = iter(
-        distribution.experimental_distribute_datasets_from_function(
+        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data),
-            distribute_lib.InputOptions(
-                experimental_prefetch_to_device=False)))
+            distribute_lib.InputOptions(experimental_prefetch_to_device=False)))
 
     local_results = distribution.experimental_local_results(
         input_iterator.get_next())
@@ -645,7 +644,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
       return dataset
 
     input_iterator = iter(
-        distribution.experimental_distribute_datasets_from_function(dataset_fn))
+        distribution.distribute_datasets_from_function(dataset_fn))
 
     @def_function.function
     def step_fn(example):
@@ -673,7 +672,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
       return dataset
 
     input_iterator = iter(
-        distribution.experimental_distribute_datasets_from_function(dataset_fn))
+        distribution.distribute_datasets_from_function(dataset_fn))
 
     @def_function.function
     def step_fn(example):
@@ -724,7 +723,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
       return dataset
 
     input_iterator = iter(
-        distribution.experimental_distribute_datasets_from_function(dataset_fn))
+        distribution.distribute_datasets_from_function(dataset_fn))
 
     @def_function.function
     def run(inputs):
@@ -750,7 +749,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
      return dataset
 
     input_iterator = iter(
-        distribution.experimental_distribute_datasets_from_function(dataset_fn))
+        distribution.distribute_datasets_from_function(dataset_fn))
 
     def embedding_lookup(inputs):
       embedding_weights = array_ops.zeros((1, 128))
@@ -935,7 +934,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
     inputs = constant_op.constant([2., 3.])
     dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
     input_iterator = iter(
-        distribution.experimental_distribute_datasets_from_function(dataset))
+        distribution.distribute_datasets_from_function(dataset))
     with distribution.scope():
       var = variables.Variable(1.0)
@@ -114,7 +114,7 @@ the same way with eager and graph execution.
   devices, with a different value for each replica. They are produced by
   iterating through a distributed dataset returned by
   `tf.distribute.Strategy.experimental_distribute_dataset` and
-  `tf.distribute.Strategy.experimental_distribute_datasets_from_function`. They
+  `tf.distribute.Strategy.distribute_datasets_from_function`. They
   are also the typical result returned by
   `tf.distribute.Strategy.run`.
 
@@ -688,7 +688,7 @@ class StrategyBase(object):
     a `tf.data.Dataset` to something that produces "per-replica" values.
     If you want to manually specify how the dataset should be partitioned
     across replicas, use
-    `tf.distribute.Strategy.experimental_distribute_datasets_from_function`
+    `tf.distribute.Strategy.distribute_datasets_from_function`
     instead.
   * Use `tf.distribute.Strategy.run` to run a function
     once per replica, taking values that may be "per-replica" (e.g.
@@ -1030,13 +1030,13 @@ class StrategyBase(object):
 
     If the above batch splitting and dataset sharding logic is undesirable,
     please use
-    `tf.distribute.Strategy.experimental_distribute_datasets_from_function`
+    `tf.distribute.Strategy.distribute_datasets_from_function`
     instead, which does not do any automatic batching or sharding for you.
 
     Note: If you are using TPUStrategy, the order in which the data is processed
     by the workers when using
     `tf.distribute.Strategy.experimental_distribute_dataset` or
-    `tf.distribute.Strategy.experimental_distribute_datasets_from_function` is
+    `tf.distribute.Strategy.distribute_datasets_from_function` is
     not guaranteed. This is typically required if you are using
     `tf.distribute` to scale prediction. You can however insert an index for
     each element in the batch and order outputs accordingly. Refer to [this
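The batch splitting the note refers to is plain integer division: `per_replica = global_batch // num_replicas_in_sync`, which `tf.distribute.InputContext.get_per_replica_batch_size` computes for you (raising if the division is not exact). A worked sketch with illustrative numbers:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # e.g. 2 replicas -> 64 // 2 == 32
GLOBAL_BATCH = 64

def dataset_fn(ctx):
  # Each replica dequeues per_replica elements per step.
  per_replica = ctx.get_per_replica_batch_size(GLOBAL_BATCH)
  return tf.data.Dataset.range(640).batch(per_replica, drop_remainder=True)

dist_ds = strategy.distribute_datasets_from_function(dataset_fn)
```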
@@ -1045,7 +1045,7 @@ class StrategyBase(object):
 
     Note: Stateful dataset transformations are currently not supported with
     `tf.distribute.experimental_distribute_dataset` or
-    `tf.distribute.experimental_distribute_datasets_from_function`. Any stateful
+    `tf.distribute.distribute_datasets_from_function`. Any stateful
     ops that the dataset may have are currently ignored. For example, if your
     dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
     then you have a dataset graph that depends on state (i.e. the random seed) on
@@ -1067,8 +1067,7 @@ class StrategyBase(object):
     # pylint: enable=line-too-long
     return self._extended._experimental_distribute_dataset(dataset, options)  # pylint: disable=protected-access
 
-  def experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                     options=None):
+  def distribute_datasets_from_function(self, dataset_fn, options=None):
     # pylint: disable=line-too-long
     """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
 
@@ -1077,7 +1076,7 @@ class StrategyBase(object):
     instance. It is expected that the returned dataset from `dataset_fn` is
     already batched by per-replica batch size (i.e. global batch size divided by
     the number of replicas in sync) and sharded.
-    `tf.distribute.Strategy.experimental_distribute_datasets_from_function` does
+    `tf.distribute.Strategy.distribute_datasets_from_function` does
     not batch or shard the `tf.data.Dataset` instance
     returned from the input function. `dataset_fn` will be called on the CPU
     device of each of the workers and each generates a dataset where every
@@ -1112,7 +1111,7 @@ class StrategyBase(object):
     Note: If you are using TPUStrategy, the order in which the data is processed
     by the workers when using
     `tf.distribute.Strategy.experimental_distribute_dataset` or
-    `tf.distribute.Strategy.experimental_distribute_datasets_from_function` is
+    `tf.distribute.Strategy.distribute_datasets_from_function` is
     not guaranteed. This is typically required if you are using
     `tf.distribute` to scale prediction. You can however insert an index for
     each element in the batch and order outputs accordingly. Refer to [this
@@ -1121,14 +1120,14 @@ class StrategyBase(object):
 
     Note: Stateful dataset transformations are currently not supported with
     `tf.distribute.experimental_distribute_dataset` or
-    `tf.distribute.experimental_distribute_datasets_from_function`. Any stateful
+    `tf.distribute.distribute_datasets_from_function`. Any stateful
     ops that the dataset may have are currently ignored. For example, if your
     dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
     then you have a dataset graph that depends on state (i.e. the random seed) on
     the local machine where the python process is being executed.
 
     For a tutorial on more usage and properties of this method, refer to the
     [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
     If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
 
     Args:
@@ -1141,9 +1140,17 @@ class StrategyBase(object):
       A `tf.distribute.DistributedDataset`.
     """
     # pylint: enable=line-too-long
-    return self._extended._experimental_distribute_datasets_from_function(  # pylint: disable=protected-access
+    return self._extended._distribute_datasets_from_function(  # pylint: disable=protected-access
         dataset_fn, options)
 
+  # TODO(b/162776748): Remove deprecated symbol.
+  @doc_controls.do_not_doc_inheritable
+  @deprecation.deprecated(None, "rename to distribute_datasets_from_function")
+  def experimental_distribute_datasets_from_function(self,
+                                                     dataset_fn,
+                                                     options=None):
+    return self.distribute_datasets_from_function(dataset_fn, options)
+
   def run(self, fn, args=(), kwargs=None, options=None):
     """Invokes `fn` on each replica, with the given arguments.
 
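The deprecated wrapper added here only logs a warning and delegates; call behavior is unchanged. The shape of that pattern in isolation (a generic sketch using the standard `warnings` module, not TensorFlow's internal `deprecation` decorator):

```python
import warnings

class Strategy:

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    # Stand-in for the real distribution logic.
    return dataset_fn(None)

  def experimental_distribute_datasets_from_function(self, dataset_fn,
                                                     options=None):
    # Warn once, then forward to the renamed method.
    warnings.warn(
        "experimental_distribute_datasets_from_function is deprecated; "
        "rename to distribute_datasets_from_function", DeprecationWarning)
    return self.distribute_datasets_from_function(dataset_fn, options)
```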
@@ -1152,7 +1159,7 @@ class StrategyBase(object):
     have `tf.distribute.DistributedValues`, such as those produced by a
     `tf.distribute.DistributedDataset` from
     `tf.distribute.Strategy.experimental_distribute_dataset` or
-    `tf.distribute.Strategy.experimental_distribute_datasets_from_function`,
+    `tf.distribute.Strategy.distribute_datasets_from_function`,
     when `fn` is executed on a particular replica, it will be executed with the
     component of `tf.distribute.DistributedValues` that correspond to that
     replica.
@@ -2196,8 +2203,7 @@ class StrategyExtendedV2(object):
   def _experimental_distribute_dataset(self, dataset, options):
     raise NotImplementedError("must be implemented in descendants")
 
-  def _experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                      options):
+  def _distribute_datasets_from_function(self, dataset_fn, options):
     raise NotImplementedError("must be implemented in descendants")
 
   def _experimental_distribute_values_from_function(self, value_fn):
@@ -3298,8 +3304,7 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
   def _experimental_distribute_dataset(self, dataset, options):
     return dataset
 
-  def _experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                      options):
+  def _distribute_datasets_from_function(self, dataset_fn, options):
     return dataset_fn(InputContext())
 
   def _experimental_distribute_values_from_function(self, value_fn):
@@ -89,8 +89,7 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
         [distribute_lib.InputContext()],
         self._container_strategy())
 
-  def _experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                      options):
+  def _distribute_datasets_from_function(self, dataset_fn, options):
     return dataset_fn(distribute_lib.InputContext())
 
   def _local_results(self, value):
@@ -546,14 +545,14 @@ class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
     if context.executing_eagerly():
       dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
       dist_dataset_from_func = \
-          default_strategy.experimental_distribute_datasets_from_function(
+          default_strategy.distribute_datasets_from_function(
               dataset_fn)
       next_val = next(iter(dist_dataset_from_func))
       self.assertAllEqual([0, 1], self.evaluate(next_val))
     else:
       dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
       dist_dataset_from_func = \
-          default_strategy.experimental_distribute_datasets_from_function(
+          default_strategy.distribute_datasets_from_function(
              dataset_fn)
       dataset_ops.make_initializable_iterator(dist_dataset_from_func)
@@ -274,7 +274,7 @@ class DistributedDatasetInterface(collections_abc.Iterable,
 
   There are two APIs to create a `tf.distribute.DistributedDataset` object:
   `tf.distribute.Strategy.experimental_distribute_dataset(dataset)` and
-  `tf.distribute.Strategy.experimental_distribute_datasets_from_function(dataset_fn)`.
+  `tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn)`.
   *When to use which?* When you have a `tf.data.Dataset` instance, and the
   regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance
   with a new batch size that is equal to the global batch size divided by the
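To make the "when to use which" concrete, a side-by-side sketch (assuming a `MirroredStrategy`; sizes are illustrative): `experimental_distribute_dataset` takes a dataset already batched to the global batch size and rebatches and shards it automatically, while `distribute_datasets_from_function` leaves batching and sharding entirely to your input function:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH = 16

# Option 1: batch globally; the strategy splits batches across replicas.
ds = tf.data.Dataset.range(256).batch(GLOBAL_BATCH)
dist_ds = strategy.experimental_distribute_dataset(ds)

# Option 2: batch per replica yourself, with full control over sharding.
def dataset_fn(ctx):
  per_replica = ctx.get_per_replica_batch_size(GLOBAL_BATCH)
  return tf.data.Dataset.range(256).batch(per_replica)

dist_ds_fn = strategy.distribute_datasets_from_function(dataset_fn)
```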
@@ -120,7 +120,7 @@ class DistributedIteratorTestBase(test.TestCase):
           split_batch_by=split_batch_by,
           input_context=input_context)
     else:
-      return strategy.experimental_distribute_datasets_from_function(dataset)
+      return strategy.distribute_datasets_from_function(dataset)
 
   def _assert_iterator_values(self,
                               iterator,
@@ -1158,8 +1158,7 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       ds = distribution.experimental_distribute_dataset(
           dataset_fn(distribute_lib.InputContext()))
     else:
-      ds = distribution.experimental_distribute_datasets_from_function(
-          dataset_fn)
+      ds = distribution.distribute_datasets_from_function(dataset_fn)
     iterator = iter(ds)
 
     self.assertEqual(iterator._enable_get_next_as_optional,
@@ -1208,8 +1207,7 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       ds = distribution.experimental_distribute_dataset(
           dataset_fn(distribute_lib.InputContext()))
     else:
-      ds = distribution.experimental_distribute_datasets_from_function(
-          dataset_fn)
+      ds = distribution.distribute_datasets_from_function(dataset_fn)
 
     # Iterate through all the batches and sum them up.
     def sum_batch(per_replica_features):
@@ -176,8 +176,7 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
           dataset_fn(distribute_lib.InputContext()))
       type_spec = ds.element_spec
     else:
-      ds = distribution.experimental_distribute_datasets_from_function(
-          dataset_fn)
+      ds = distribution.distribute_datasets_from_function(dataset_fn)
     iterator = iter(ds)
     _check_type_spec_structure(iterator)
     type_spec = iterator.element_spec
@@ -507,8 +507,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     return numpy_dataset.one_host_numpy_dataset(
         numpy_input, self._host_input_device, session)
 
-  def _experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                      options):
+  def _distribute_datasets_from_function(self, dataset_fn, options):
     input_contexts = []
     input_workers = self._input_workers_with_options(options)
     num_workers = input_workers.num_workers
@@ -109,8 +109,10 @@ class OneDeviceStrategy(distribute_lib.Strategy):
     return super(OneDeviceStrategy, self).experimental_distribute_dataset(
         dataset, options)
 
-  def experimental_distribute_datasets_from_function(self, dataset_fn,  # pylint: disable=useless-super-delegation
-                                                     options=None):
+  def distribute_datasets_from_function(
+      self,
+      dataset_fn,  # pylint: disable=useless-super-delegation
+      options=None):
     """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
 
     `dataset_fn` will be called once for each worker in the strategy. In this
@@ -127,7 +129,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
       return d.shard(
           input_context.num_input_pipelines, input_context.input_pipeline_id)
 
-    inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn)
+    inputs = strategy.distribute_datasets_from_function(dataset_fn)
 
     for batch in inputs:
       replica_results = strategy.run(replica_fn, args=(batch,))
@@ -148,9 +150,8 @@ class OneDeviceStrategy(distribute_lib.Strategy):
       A "distributed `Dataset`", which the caller can iterate over like regular
       datasets.
     """
-    return super(
-        OneDeviceStrategy, self).experimental_distribute_datasets_from_function(
-            dataset_fn, options)
+    return super(OneDeviceStrategy,
+                 self).distribute_datasets_from_function(dataset_fn, options)
 
   def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
     """Returns the list of all local per-replica values contained in `value`.
@@ -316,8 +317,7 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
         self._input_workers_with_options(options),
         self._container_strategy())
 
-  def _experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                      options):
+  def _distribute_datasets_from_function(self, dataset_fn, options):
     return input_lib.get_distributed_datasets_from_function(
         dataset_fn,
         self._input_workers_with_options(options),
@@ -128,12 +128,10 @@ class ParameterServerStrategy(distribute_lib.Strategy):
                 self).experimental_distribute_dataset(dataset=dataset,
                                                       options=options)
 
-  def experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                     options=None):
+  def distribute_datasets_from_function(self, dataset_fn, options=None):
     self._raise_pss_error_if_eager()
-    super(ParameterServerStrategy,
-          self).experimental_distribute_datasets_from_function(
-              dataset_fn=dataset_fn, options=options)
+    super(ParameterServerStrategy, self).distribute_datasets_from_function(
+        dataset_fn=dataset_fn, options=options)
 
   def run(self, fn, args=(), kwargs=None, options=None):
     self._raise_pss_error_if_eager()
@@ -387,8 +385,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     return numpy_dataset.one_host_numpy_dataset(
         numpy_input, self._input_host_device, session)
 
-  def _experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                      options):
+  def _distribute_datasets_from_function(self, dataset_fn, options):
     if self._cluster_spec:
       input_pipeline_id = multi_worker_util.id_in_cluster(
           self._cluster_spec, self._task_type, self._task_id)
@@ -759,10 +759,9 @@ class ParameterServerStrategyTest(
         strategy.experimental_distribute_dataset,
         dataset.batch(2))
 
-    self.assertRaisesRegex(
-        NotImplementedError, 'ParameterServerStrategy*',
-        strategy.experimental_distribute_datasets_from_function,
-        lambda _: dataset)
+    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
+                           strategy.distribute_datasets_from_function,
+                           lambda _: dataset)
 
     self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                            strategy.scope)
@@ -605,7 +605,7 @@ class DistributedCollectiveAllReduceStrategyTest(
 
     expected_sum_on_workers = {'chief': 10, 'worker': 35}
     input_iterator = iter(
-        strategy.experimental_distribute_datasets_from_function(dataset_fn))
+        strategy.distribute_datasets_from_function(dataset_fn))
 
     @def_function.function
     def run(iterator):
@@ -648,7 +648,7 @@ class DistributedCollectiveAllReduceStrategyTest(
             input_context.input_pipeline_id)
 
     input_iterator = iter(
-        strategy.experimental_distribute_datasets_from_function(dataset_fn))
+        strategy.distribute_datasets_from_function(dataset_fn))
 
     @def_function.function
     def run(input_iterator):
@@ -340,7 +340,7 @@ class DistributionTestBase(test.TestCase):
       self, strategy, input_fn, expected_values, ignore_order=False):
     assert_same = self.assertCountEqual if ignore_order else self.assertEqual
 
-    iterable = strategy.experimental_distribute_datasets_from_function(input_fn)
+    iterable = strategy.distribute_datasets_from_function(input_fn)
     if context.executing_eagerly():
       iterator = iter(iterable)
@@ -125,7 +125,7 @@ class TPUStrategyV2(distribute_lib.Strategy):
   `strategy.run` is called inside a `tf.function` if eager
   behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.
 
-  `experimental_distribute_datasets_from_function` and
+  `distribute_datasets_from_function` and
   `experimental_distribute_dataset` APIs can be used to distribute the dataset
   across the TPU workers when writing your own training loop. If you are using
   `fit` and `compile` methods available in `tf.keras.Model`, then Keras will
@@ -144,7 +144,7 @@ class TPUStrategyV2(distribute_lib.Strategy):
   ...   y = np.random.randint(2, size=(2, 1))
   ...   dataset = tf.data.Dataset.from_tensor_slices((x, y))
   ...   return dataset.repeat().batch(1, drop_remainder=True)
-  >>> dist_dataset = strategy.experimental_distribute_datasets_from_function(
+  >>> dist_dataset = strategy.distribute_datasets_from_function(
   ...     dataset_fn)
   >>> iterator = iter(dist_dataset)
 
@@ -190,7 +190,7 @@ class TPUStrategyV2(distribute_lib.Strategy):
   ...   # Add operation will be executed on logical device 0.
   ...   output = strategy.experimental_assign_to_logical_device(output, 0)
   ...   return output
-  >>> dist_dataset = strategy.experimental_distribute_datasets_from_function(
+  >>> dist_dataset = strategy.distribute_datasets_from_function(
   ...     dataset_fn)
   >>> iterator = iter(dist_dataset)
   >>> strategy.run(step_fn, args=(next(iterator),))
@@ -229,7 +229,7 @@ class TPUStrategyV2(distribute_lib.Strategy):
     `tf.distribute.DistributedValues`, such as those produced by a
     `tf.distribute.DistributedDataset` from
     `tf.distribute.Strategy.experimental_distribute_dataset` or
-    `tf.distribute.Strategy.experimental_distribute_datasets_from_function`,
+    `tf.distribute.Strategy.distribute_datasets_from_function`,
     when `fn` is executed on a particular replica, it will be executed with the
     component of `tf.distribute.DistributedValues` that correspond to that
     replica.
@@ -811,8 +811,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         self._container_strategy(),
         split_batch_by=self._num_replicas_in_sync)
 
-  def _experimental_distribute_datasets_from_function(self, dataset_fn,
-                                                      options):
+  def _distribute_datasets_from_function(self, dataset_fn, options):
     input_workers = self._get_input_workers(options)
     input_contexts = []
     num_workers = input_workers.num_workers
@@ -509,10 +509,9 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
       return dataset.map(make_sparse)
 
     dataset = iter(
-        strategy.experimental_distribute_datasets_from_function(
+        strategy.distribute_datasets_from_function(
            dataset_fn,
-            distribute_lib.InputOptions(
-                experimental_prefetch_to_device=False)))
+            distribute_lib.InputOptions(experimental_prefetch_to_device=False)))
 
     sparse, result = sparse_lookup(dataset)
 
@@ -560,10 +559,9 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
       return dataset.map(make_sparse)
 
     dataset = iter(
-        strategy.experimental_distribute_datasets_from_function(
+        strategy.distribute_datasets_from_function(
            dataset_fn,
-            distribute_lib.InputOptions(
-                experimental_prefetch_to_device=False)))
+            distribute_lib.InputOptions(experimental_prefetch_to_device=False)))
 
     output = sparse_lookup(dataset)
 
@@ -616,7 +614,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
       return dataset.map(make_sparse)
 
     dataset = iter(
-        strategy.experimental_distribute_datasets_from_function(
+        strategy.distribute_datasets_from_function(
            dataset_fn,
            options=distribute_lib.InputOptions(
                experimental_prefetch_to_device=False)))
@@ -730,7 +728,7 @@ class TPUStrategyDataPrefetchTest(test.TestCase):
       return dataset.batch(strategy.num_replicas_in_sync)
 
     with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
-      iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
+      iter(strategy.distribute_datasets_from_function(dataset_fn))
 
   def test_prefetch_to_device_ragged_dataset_fn(self):
     strategy = get_tpu_strategy()
@@ -745,7 +743,7 @@ class TPUStrategyDataPrefetchTest(test.TestCase):
      return dataset.batch(strategy.num_replicas_in_sync)
 
     with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
-      iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
+      iter(strategy.distribute_datasets_from_function(dataset_fn))
 
 
 class TPUStrategyDistributionTest(
@@ -1732,10 +1732,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
       ds = ds.batch(5).repeat()
       return ds
 
-    ds = distribution.experimental_distribute_datasets_from_function(
-        make_dataset)
-    val_ds = distribution.experimental_distribute_datasets_from_function(
-        make_dataset)
+    ds = distribution.distribute_datasets_from_function(make_dataset)
+    val_ds = distribution.distribute_datasets_from_function(make_dataset)
 
     model.fit(
         ds,
@@ -152,7 +152,7 @@ class TPUEmbedding(tracking.AutoTrackable):
 
   ```python
   distributed_dataset = (
-      strategy.experimental_distribute_datasets_from_function(
+      strategy.distribute_datasets_from_function(
           dataset_fn=...,
           options=tf.distribute.InputOptions(
               experimental_prefetch_to_device=False))
@@ -583,7 +583,7 @@ class TPUEmbedding(tracking.AutoTrackable):
   embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
 
   distributed_dataset = (
-      strategy.experimental_distribute_datasets_from_function(
+      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_prefetch_to_device=False))
@@ -680,7 +680,7 @@ class TPUEmbedding(tracking.AutoTrackable):
   embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
 
   distributed_dataset = (
-      strategy.experimental_distribute_datasets_from_function(
+      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_prefetch_to_device=False))
@@ -1125,7 +1125,7 @@ class TPUEmbedding(tracking.AutoTrackable):
     features matches the per core batch size. This will automatically happen if
     your input dataset is batched to the global batch size and you use
     `tf.distribute.TPUStrategy`'s `experimental_distribute_dataset`
-    or if you use `experimental_distribute_datasets_from_function` and batch
+    or if you use `distribute_datasets_from_function` and batch
     to the per core batch size computed by the context passed to your input
     function.
 
@@ -1135,7 +1135,7 @@ class TPUEmbedding(tracking.AutoTrackable):
   embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
 
   distributed_dataset = (
-      strategy.experimental_distribute_datasets_from_function(
+      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_prefetch_to_device=False))
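The `TPUEmbedding` docstrings above all pass `experimental_prefetch_to_device=False` because embedding lookups need their inputs to stay in host memory rather than being prefetched onto the TPU. The recurring pattern, spelled out as one piece (a sketch; `strategy` is assumed to be a connected `tf.distribute.TPUStrategy`, and the batch size is illustrative):

```python
import tensorflow as tf

def dataset_fn(ctx):
  # TPUEmbedding expects features batched to the per-core batch size.
  per_core_batch = ctx.get_per_replica_batch_size(32)
  return tf.data.Dataset.range(1024).batch(per_core_batch,
                                           drop_remainder=True)

# strategy: a tf.distribute.TPUStrategy already connected to a TPU.
distributed_dataset = strategy.distribute_datasets_from_function(
    dataset_fn,
    options=tf.distribute.InputOptions(
        experimental_prefetch_to_device=False))
```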
@@ -427,7 +427,7 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
 
     input_fn = self._create_dense_input_fn(strategy)
-    dist = strategy.experimental_distribute_datasets_from_function(
+    dist = strategy.distribute_datasets_from_function(
         input_fn,
         options=distribute_lib.InputOptions(
             experimental_prefetch_to_device=False))
@@ -556,7 +556,7 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
 
     input_fn = self._create_dense_input_fn(strategy, include_weights=True)
-    dist = strategy.experimental_distribute_datasets_from_function(
+    dist = strategy.distribute_datasets_from_function(
         input_fn,
         options=distribute_lib.InputOptions(
             experimental_prefetch_to_device=False))
@@ -744,8 +744,7 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
 
     input_fn = self._create_dense_input_fn(strategy)
-    sparse_iter = iter(strategy.experimental_distribute_datasets_from_function(
-        input_fn))
+    sparse_iter = iter(strategy.distribute_datasets_from_function(input_fn))
 
     @def_function.function
     def test_fn():
@@ -768,8 +767,7 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
 
     input_fn = self._create_dense_input_fn(strategy)
-    sparse_iter = iter(strategy.experimental_distribute_datasets_from_function(
-        input_fn))
+    sparse_iter = iter(strategy.distribute_datasets_from_function(input_fn))
 
     @def_function.function
     def test_fn():
@@ -1127,7 +1125,8 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     def input_fn(ctx):
       del ctx
       return dataset_ops.DatasetV2.from_tensors(feature).repeat()
-    dist = strategy.experimental_distribute_datasets_from_function(
+
+    dist = strategy.distribute_datasets_from_function(
         input_fn,
         options=distribute_lib.InputOptions(
             experimental_prefetch_to_device=False))
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -27,6 +27,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -32,6 +32,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -27,6 +27,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_assign_to_logical_device"
     argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "configure"
     argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "distribute_datasets_from_function"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "experimental_distribute_dataset"
     argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "