Add tf.distribute.Strategy.experimental_distribute_dataset, which clones the dataset onto the different devices and returns a DistributedDataset. You can iterate over this dataset in eager mode in a Pythonic fashion, since we implement the __iter__ protocol for the wrapped dataset.
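
The following is a rough usage sketch, not part of the commit itself; it assumes a TF 2.x eager environment with MirroredStrategy and a trivial train_step, mirroring the docstring example and the new custom_training_loop_test below:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Batch by the global batch size; the strategy splits each global batch
# across the replicas in sync.
dataset = tf.data.Dataset.range(10).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

def train_step(x):
  # A real step would compute and apply gradients; here we just echo the input.
  return x

# DistributedDataset implements __iter__, so a plain for loop works eagerly.
for x in dist_dataset:
  per_replica = strategy.experimental_run_v2(train_step, args=(x,))
  print(strategy.experimental_local_results(per_replica))
```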

PiperOrigin-RevId: 243459250
Anjali Sridhar 2019-04-13 19:01:24 -07:00 committed by TensorFlower Gardener
parent fb695b89e0
commit 50eee22d86
24 changed files with 681 additions and 177 deletions

View File

@@ -764,6 +764,24 @@ cuda_py_test(
],
)
distribute_py_test(
name = "custom_training_loop_test",
srcs = ["custom_training_loop_test.py"],
main = "custom_training_loop_test.py",
tags = [
"multi_and_single_gpu",
],
deps = [
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/eager:test",
"@absl_py//absl/testing:parameterized",
],
)
distribute_py_test(
name = "minimize_loss_test",
srcs = ["minimize_loss_test.py"],

View File

@@ -350,6 +350,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
num_replicas_in_sync=self._num_replicas_in_sync)
return input_context
def _experimental_distribute_dataset(self, dataset):
input_context = self._make_input_context()
return input_lib.get_distributed_dataset(dataset, self._input_workers,
self._num_replicas_in_sync,
input_context=input_context)
def _make_dataset_iterator(self, dataset):
"""Distributes the dataset to each local GPU."""
input_context = self._make_input_context()

View File

@@ -0,0 +1,114 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
class InputIterationTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.strategies_minus_tpu,
mode=["eager"]
))
def testFullEager(self, distribution):
dataset = self._get_dataset()
def train_step(data):
return data
dist_dataset = distribution.experimental_distribute_dataset(dataset)
results = []
for x in dist_dataset:
output = distribution.experimental_local_results(
distribution.experimental_run_v2(train_step, args=(x,)))
results.append(output)
self._validate_outputs(results)
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.strategies_minus_tpu,
mode=["eager"]
))
def testStepInFunction(self, distribution):
dataset = self._get_dataset()
@def_function.function
def train_step(data):
return data
dist_dataset = distribution.experimental_distribute_dataset(dataset)
results = []
for x in dist_dataset:
output = distribution.experimental_local_results(
distribution.experimental_run_v2(train_step, args=(x,)))
results.append(output)
self._validate_outputs(results)
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.strategies_minus_tpu +
[strategy_combinations.tpu_strategy_one_step],
mode=["eager"]
))
def testRunInFunction(self, distribution):
dataset = self._get_dataset()
def train_step(data):
return data
@def_function.function
def f_train_step(input_data):
return distribution.experimental_local_results(
distribution.experimental_run_v2(train_step, args=(input_data,)))
dist_dataset = distribution.experimental_distribute_dataset(dataset)
results = []
for x in dist_dataset:
output = f_train_step(x)
results.append(output)
self._validate_outputs(results)
def _get_dataset(self):
if tf2.enabled():
return dataset_ops.DatasetV2.range(10).batch(2)
else:
return dataset_ops.Dataset.range(10).batch(2)
def _validate_outputs(self, actual_results):
expected_results = [[i, i+1] for i in range(0, 10, 2)]
self.assertEqual(len(expected_results), len(actual_results))
for i, expected_result in enumerate(expected_results):
final_result = []
actual_result = actual_results[i]
for val in actual_result:
final_result.extend(val.numpy())
self.assertAllEqual(expected_result, final_result)
if __name__ == "__main__":
test.main()

View File

@@ -377,6 +377,39 @@ class Strategy(object):
args = (input_iterator.get_next(),) if input_iterator is not None else ()
return self.experimental_run_v2(fn, args=args)
def experimental_distribute_dataset(self, dataset):
"""Distributes a tf.data.Dataset instance provided via `dataset`.
Data from the given dataset will be distributed evenly across all the
compute replicas. This function assumes that the input dataset is batched
by the global batch size.
The following is an example:
```python
strategy = tf.distribute.MirroredStrategy()
# Create a dataset
dataset = dataset_ops.Dataset.range(10).batch(2)
# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)
# Iterate over the distributed dataset
for x in dist_dataset:
# process dataset elements
strategy.experimental_run_v2(train_step, args=(x,))
```
Args:
dataset: `tf.data.Dataset` that will be distributed evenly across all
replicas.
Returns:
A `DistributedDataset` which returns inputs for each step of the
computation.
"""
return self._extended._experimental_distribute_dataset(dataset) # pylint: disable=protected-access
def experimental_run_v2(self, fn, args=(), kwargs=None):
"""Runs ops in `fn` on each replica, with the given arguments.
@@ -1062,6 +1095,9 @@ class StrategyExtendedV2(object):
def _make_input_fn_iterator(self, input_fn, replication_mode):
raise NotImplementedError("must be implemented in descendants")
def _experimental_distribute_dataset(self, dataset):
raise NotImplementedError("must be implemented in descendants")
def _reduce(self, reduce_op, value):
# Default implementation until we have an implementation for each strategy.
return self._local_results(
@@ -1671,6 +1707,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
def variable_created_in_scope(self, v):
return v._distribute_strategy is None # pylint: disable=protected-access
def _experimental_distribute_dataset(self, dataset):
return dataset
def _make_dataset_iterator(self, dataset):
return _DefaultDistributionExtended.DefaultInputIterator(dataset)

View File

@@ -29,6 +29,7 @@ from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -38,6 +39,35 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
def get_distributed_dataset(dataset, input_workers, split_batch_by=None,
input_context=None):
"""Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
This is a common function that is used by all strategies to return the right
tf.data.Dataset wrapped instance depending on the `dataset` argument type.
Args:
dataset: a tf.data.DatasetV1 or tf.data.DatasetV2 instance.
input_workers: an InputWorkers object which specifies devices on which
iterators should be created.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and
`num_input_pipelines` in the `InputContext`.
Returns:
A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
"""
if isinstance(dataset, dataset_ops.DatasetV1):
return DistributedDatasetV1(dataset, input_workers, split_batch_by,
input_context)
else:
return DistributedDataset(dataset, input_workers, split_batch_by,
input_context)
class InputWorkers(object):
"""A 1-to-many mapping from input worker devices to compute devices."""
@@ -95,31 +125,7 @@ class InputWorkers(object):
self.__class__.__name__, debug_repr, self._device_map)
class InputIterator(object):
"""An input iterator, intended to be passed to `DistributionStrategy.run`."""
def get_next(self):
"""Returns the next inputs for all replicas."""
raise NotImplementedError("must be implemented in descendants")
def initialize(self):
"""Initialize the underlying input dataset, when applicable.
In eager mode, this will create a new iterator and return it.
In graph mode, this will initialize the same underlying iterator(s).
Users are required to call this if
- This iterator was returned from a call to `make_input_fn_iterator` with an
input function that returns a dataset.
- Or this iterator was returned from a call to `make_dataset_iterator`.
Returns:
A list of initialization ops to be executed.
"""
raise NotImplementedError("must be implemented in descendants")
class InputIteratorImpl(InputIterator):
class DistributedIterator(object):
"""Common implementation for all input iterators."""
def __init__(self, input_workers, iterators, **kwargs):
@@ -127,11 +133,11 @@ class InputIteratorImpl(InputIterator):
# be correctly handled.
self._enable_get_next_as_optional = False
if len(kwargs) > 1:
raise ValueError("InputIteratorImpl constructor only takes one "
raise ValueError("DistributedIterator constructor only takes one "
"experimental flag now")
if len(kwargs) == 1:
if "_enable_get_next_as_optional" not in kwargs:
raise ValueError("InputIteratorImpl constructor does not support "
raise ValueError("DistributedIterator constructor does not support "
"arguments: {}".format(kwargs))
self._enable_get_next_as_optional = (
kwargs["_enable_get_next_as_optional"])
@@ -143,6 +149,18 @@ class InputIteratorImpl(InputIterator):
self._iterators = iterators
self._input_workers = input_workers
def next(self):
return self.__next__()
def __next__(self):
if not context.executing_eagerly():
raise RuntimeError("__iter__ is only supported "
"when eager execution is enabled.")
try:
return self.get_next()
except errors.OutOfRangeError:
raise StopIteration
def get_next(self, name=None):
"""Returns the next input from the iterator for all replicas."""
if not self._enable_get_next_as_optional:
@@ -226,16 +244,33 @@ class InputIteratorImpl(InputIterator):
return values.regroup(self._input_workers.device_map, replicas)
# We need a private initializer method for re-initializing multidevice
# iterators when used with Keras training loops. If we don't reinitialize the
# iterator we run into memory leak issues (b/123315763).
@property
def _initializer(self):
init_ops = []
for it in self._iterators:
init_ops.extend(it.initialize())
return control_flow_ops.group(init_ops)
class DistributedIteratorV1(DistributedIterator):
"""Input Iterator for tf.data.DatasetV1."""
# TODO(anjalisridhar): Move to using `initializer` instead to be consistent
# with tf.data iterator APIs.
def initialize(self):
"""Initialze underlying iterators.
Returns:
A list of any initializer ops that should be run.
"""
init_ops = []
for it in self._iterators:
init_ops.extend(it.initialize())
return init_ops
return super(DistributedIteratorV1, self)._initializer
@property
def initializer(self):
return self.initialize()
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
@@ -260,7 +295,113 @@ class InputIteratorImpl(InputIterator):
return None
class InputFunctionIterator(InputIteratorImpl):
class DistributedDataset(object):
"""Wrapped tf.data.DatasetV2 that supports prefetching to multiple devices."""
def __init__(self, dataset, input_workers, split_batch_by=None,
input_context=None, **kwargs):
"""Distribute the dataset on all workers.
If `split_batch_by` is not None, we "split" each batch of the dataset by
`split_batch_by` value.
Args:
dataset: `tf.data.Dataset` that will be used as the input source.
input_workers: an `InputWorkers` object.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and
`num_input_pipelines` in the `InputContext`.
**kwargs: Additional experimental flags. Will be removed in future.
"""
# We clone and shard the dataset on each worker. The current setup tries to
# shard the dataset by files if possible so that each worker sees a
# different subset of files. If that is not possible, will attempt to shard
# the final input such that each worker will run the entire preprocessing
# pipeline and only receive its own shard of the dataset.
assert isinstance(input_workers, InputWorkers)
if split_batch_by:
dataset = batching._RebatchDataset(dataset, split_batch_by) # pylint: disable=protected-access
self._cloned_datasets = []
if input_context:
# Between-graph where we rely on the input_context for sharding
assert input_workers.num_workers == 1
dataset = input_ops.auto_shard_dataset( # pylint: disable=protected-access
dataset, input_context.num_input_pipelines,
input_context.input_pipeline_id)
self._cloned_datasets.append(dataset)
else:
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
cloned_dataset = dataset
if not context.executing_eagerly():
cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access
cloned_dataset = cloned_dataset.with_options(dataset.options())
# TODO(b/129506833): Figure out between graph cases
cloned_dataset = input_ops.auto_shard_dataset( # pylint: disable=protected-access
cloned_dataset, len(input_workers.worker_devices), i)
self._cloned_datasets.append(cloned_dataset)
self._input_workers = input_workers
# TODO(anjalisridhar): Identify if we need to set this property on the
# iterator.
self._element_structure = dataset._element_structure # pylint: disable=protected-access
self._kwargs = kwargs
def __iter__(self):
# TODO(anjalisridhar): Remove this restriction once we can create
# iterators in graph mode.
if context.executing_eagerly():
worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
self._input_workers)
iterator = DistributedIterator(self._input_workers, worker_iterators,
**self._kwargs)
iterator._element_structure = self._element_structure # pylint: disable=protected-access
return iterator
else:
raise RuntimeError("__iter__ is only supported when eager "
"execution is enabled.")
class DistributedDatasetV1(DistributedDataset):
"""Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices."""
def __init__(self, dataset, input_workers, split_batch_by=None,
input_context=None, **kwargs):
self._input_workers = input_workers
super(DistributedDatasetV1, self).__init__(dataset, input_workers,
split_batch_by=split_batch_by,
input_context=input_context,
**kwargs)
def make_one_shot_iterator(self):
"""Get a one time use iterator for DistributedDatasetV1."""
return self._get_iterator()
def make_initializable_iterator(self):
"""Get an initializable iterator for DistributedDatasetV1."""
# Eager mode generates already initialized iterators. Hence we cannot create
# an initializable iterator.
if context.executing_eagerly():
raise ValueError("Cannot create initializable iterator in Eager mode. "
"Please use `make_one_shot_iterator` instead.")
return self._get_iterator()
def _get_iterator(self):
worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
self._input_workers)
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
**self._kwargs)
iterator._element_structure = self._element_structure # pylint: disable=protected-access
return iterator
# TODO(anjalisridhar): This class will be soon be removed in favor of newer
# APIs.
class InputFunctionIterator(DistributedIteratorV1):
"""Iterator created from input function."""
def __init__(self, input_fn, input_workers, input_contexts, **kwargs):
@@ -305,7 +446,9 @@ class InputFunctionIterator(InputIteratorImpl):
input_workers, iterators, **kwargs)
class DatasetIterator(InputIteratorImpl):
# TODO(anjalisridhar): This class will soon be removed and users should move
# to using DistributedIterator.
class DatasetIterator(DistributedIteratorV1):
"""Iterator created from input dataset."""
def __init__(self, dataset, input_workers, split_batch_by=None,
@@ -313,20 +456,7 @@ class DatasetIterator(InputIteratorImpl):
"""Make an iterator for the dataset on given devices.
If `split_batch_by` is not None, we "split" each batch of the
dataset by `split_batch_by` value. To achieve this, we first unbatch the
input dataset and then rebatch it with the per replica batch size that is
calculated using `global_batch_size // split_batch_by`.
The currently supported datasets are as follows:
`dataset.batch()` is the last operation on the dataset OR
`dataset.apply(map_and_batch)` is the last operation on the dataset OR
`dataset.batch().prefetch()` are the last 2 operations on the dataset OR
`dataset.apply(map_and_batch).prefetch()` are the last 2 operations.
We clone and shard the dataset on each worker. The current setup tries to
shard the dataset by files if possible so that each worker sees a different
subset of files. If that is not possible, will attempt to shard the final
input such that each worker will run the entire preprocessing pipeline and
only receive its own shard of the dataset.
dataset by `split_batch_by` value.
Args:
dataset: `tf.data.Dataset` that will be used as the input source.
@@ -339,40 +469,14 @@ class DatasetIterator(InputIteratorImpl):
`num_input_pipelines` in the `InputContext`.
**kwargs: Additional experimental flags. Will be removed in future.
"""
assert isinstance(input_workers, InputWorkers)
if split_batch_by:
dataset = batching._RebatchDataset(dataset, split_batch_by) # pylint: disable=protected-access
iterators = []
if input_context:
# Between-graph where we rely on the input_context for sharding
assert input_workers.num_workers == 1
worker = input_workers.worker_devices[0]
worker_devices = input_workers.compute_devices_for_worker(0)
dataset = input_ops.auto_shard_dataset( # pylint: disable=protected-access
dataset, input_context.num_input_pipelines,
input_context.input_pipeline_id)
iterator = _SingleWorkerDatasetIterator(dataset, worker, worker_devices)
iterators.append(iterator)
else:
# In-graph cases where we depend on the list of workers for sharding
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i)
cloned_dataset = dataset
if not context.executing_eagerly():
cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access
cloned_dataset = cloned_dataset.with_options(dataset.options())
cloned_dataset = input_ops.auto_shard_dataset( # pylint: disable=protected-access
cloned_dataset, len(input_workers.worker_devices), i)
iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
worker_devices)
iterators.append(iterator)
self._element_structure = dataset._element_structure # pylint: disable=protected-access
super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
dist_dataset = DistributedDatasetV1(dataset, input_workers,
split_batch_by=split_batch_by,
input_context=input_context)
worker_iterators = _create_iterators_per_worker(
dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access
super(DatasetIterator, self).__init__(input_workers, worker_iterators, # pylint: disable=protected-access
**kwargs)
self._element_structure = dist_dataset._element_structure # pylint: disable=protected-access
def _dummy_tensor_fn(value_structure):
@@ -554,6 +658,21 @@ class _SingleWorkerCallableIterator(object):
return []
def _create_iterators_per_worker(worker_datasets, input_workers):
"""Create a multidevice iterator on each of the workers."""
assert isinstance(input_workers, InputWorkers)
assert len(worker_datasets) == len(input_workers.worker_devices)
iterators = []
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i)
iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
worker_devices)
iterators.append(iterator)
return iterators
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
"""Get the batched dataset from `d`."""

View File

@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
@@ -32,14 +33,10 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest
class InputIteratorTestBase(test.TestCase):
def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
devices, split_batch_by,
enable_get_next_as_optional):
device_map = values.ReplicaDeviceMap(devices)
input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
class DistributedIteratorTestBase(test.TestCase):
def _wrap_iterator(self, input_type, dataset_fn, input_workers, devices,
split_batch_by, enable_get_next_as_optional):
if input_type == "input_fn":
input_contexts = []
for i in range(input_workers.num_workers):
@@ -59,116 +56,223 @@ class InputIteratorTestBase(test.TestCase):
_enable_get_next_as_optional=enable_get_next_as_optional)
return iterator
def _test_iterator(self,
input_type,
dataset_fn,
worker_device_pairs,
expected_values,
sess=None,
split_batch_by=None,
enable_get_next_as_optional=False):
def _wrap_dataset(self, input_type, dataset, input_workers,
split_batch_by, enable_get_next_as_optional):
if isinstance(dataset, dataset_ops.Dataset):
return input_lib.DistributedDatasetV1(
dataset, input_workers,
split_batch_by,
_enable_get_next_as_optional=enable_get_next_as_optional)
else:
return input_lib.DistributedDataset(
dataset, input_workers,
split_batch_by,
_enable_get_next_as_optional=enable_get_next_as_optional)
def _test_input_iteration(self,
input_type,
api_type,
iteration_type,
dataset_fn,
worker_device_pairs,
expected_values,
sess=None,
split_batch_by=None,
enable_get_next_as_optional=False):
if iteration_type == "for_loop" and not context.executing_eagerly():
self.skipTest("unsupported test combination.")
if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
self.skipTest("unsupported test combination.")
if api_type == "wrap_into_dataset" and input_type == "input_fn":
self.skipTest("unsupported test combination.")
devices = nest.flatten([ds for _, ds in worker_device_pairs])
iterator = self._create_iterator(
input_type, dataset_fn, worker_device_pairs, devices, split_batch_by,
enable_get_next_as_optional)
device_map = values.ReplicaDeviceMap(devices)
input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
evaluate(control_flow_ops.group(iterator.initialize()))
if api_type == "wrap_into_iterator":
iterator = self._wrap_iterator(
input_type, dataset_fn, input_workers, devices, split_batch_by,
enable_get_next_as_optional)
else:
# wrapping into a dataset:
given_dataset = dataset_fn(distribute_lib.InputContext())
dataset = self._wrap_dataset(input_type, given_dataset, input_workers,
split_batch_by, enable_get_next_as_optional)
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r, next_element) for r in range(len(devices))])
self.assertEqual(len(expected_value), len(computed_value))
for i in range(len(expected_value)):
self.assertAllEqual(expected_value[i], computed_value[i])
if context.executing_eagerly():
iterator = iter(dataset)
else:
# In graph mode currently we only have support for creating iterators
# for datasetV1 instances.
if not isinstance(dataset, dataset_ops.DatasetV1):
self.skipTest("unsupported test combination")
iterator = dataset.make_one_shot_iterator()
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
evaluate(
[values.select_replica(r, next_element) for r in range(len(devices))])
if iteration_type == "get_next":
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
if isinstance(iterator, input_lib.DistributedIteratorV1):
evaluate(control_flow_ops.group(iterator.initialize()))
else:
evaluate(control_flow_ops.group(iterator._initializer))
# After re-initializing the iterator, should be able to iterate again.
evaluate(control_flow_ops.group(iterator.initialize()))
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r,
next_element) for r in range(len(devices))])
self.assertEqual(len(expected_value), len(computed_value))
for i in range(len(expected_value)):
self.assertAllEqual(expected_value[i], computed_value[i])
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r, next_element) for r in range(len(devices))])
self.assertEqual(len(expected_value), len(computed_value))
for i in range(len(expected_value)):
self.assertAllEqual(expected_value[i], computed_value[i])
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
evaluate(
[values.select_replica(r,
next_element) for r in range(len(devices))])
# After re-initializing the iterator, should be able to iterate again.
if isinstance(iterator, input_lib.DistributedIteratorV1):
evaluate(control_flow_ops.group(iterator.initialize()))
else:
evaluate(control_flow_ops.group(iterator._initializer))
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r,
next_element) for r in range(len(devices))])
self.assertEqual(len(expected_value), len(computed_value))
for i in range(len(expected_value)):
self.assertAllEqual(expected_value[i], computed_value[i])
if iteration_type == "for_loop" and context.executing_eagerly():
actual_values = []
for x in dataset:
computed_value = self.evaluate(
[values.select_replica(r, x) for r in range(len(devices))])
actual_values.append(computed_value)
for i, expected_value in enumerate(expected_values):
self.assertEqual(len(expected_value), len(actual_values[i]))
for j in range(len(expected_value)):
self.assertAllEqual(expected_value[j], actual_values[i][j])
class InputIteratorSingleWorkerTest(InputIteratorTestBase,
parameterized.TestCase):
class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
parameterized.TestCase):
def testGraphModeError(self):
with context.graph_mode():
worker_device_pairs = [("", ["/device:CPU:0"])]
devices = nest.flatten([ds for _, ds in worker_device_pairs])
device_map = values.ReplicaDeviceMap(devices)
input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
dataset = dataset_ops.Dataset.range(10).batch(2)
with self.assertRaisesRegexp(RuntimeError,
"__iter__ is only "
"supported when eager execution is "
"enabled."):
dist_dataset = input_lib.DistributedDatasetV1(dataset, input_workers)
iter(dist_dataset)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"]))
def testOneDeviceCPU(self, input_type):
input_type=["input_fn", "dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"]))
def testOneDeviceCPU(self, input_type, api_type, iteration_type):
worker_device_pairs = [("", ["/device:CPU:0"])]
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
expected_values = [[i] for i in range(10)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
worker_device_pairs, expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
required_gpus=1))
def testTwoDevicesOneGPUOneCPU(self, input_type):
def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
worker_device_pairs, expected_values)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
required_gpus=1))
def testTupleDataset(self, input_type):
def testTupleDataset(self, input_type, api_type, iteration_type):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
def dataset_fn(ctx):
del ctx
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
if tf2.enabled():
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
else:
dataset1 = dataset_ops.DatasetV2.range(10)
dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2)
return dataset_ops.DatasetV2.zip((dataset1, dataset2))
expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values)
self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
worker_device_pairs, expected_values)
@combinations.generate(
combinations.combine(
mode=["graph", "eager"],
input_type=["input_fn", "dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
required_gpus=1))
def testUnevenDatasetBatches(self, input_type):
def testUnevenDatasetBatches(self, input_type, api_type, iteration_type):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(2)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
# The last global batch only contains data for one replica.
expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values, enable_get_next_as_optional=True)
self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
worker_device_pairs, expected_values,
enable_get_next_as_optional=True)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
input_type=["dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
split_batch_by=[None, 2],
required_gpus=1))
def testBatchSplitting(self, input_type, split_batch_by):
def testBatchSplitting(self, input_type, api_type, iteration_type,
split_batch_by):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
batch_size = 10
dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
updated_batch_size = (
batch_size // split_batch_by if split_batch_by else batch_size)
@@ -176,13 +280,13 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
range(i+updated_batch_size, i+2*updated_batch_size)]
for i in range(0, 100, updated_batch_size*2)]
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
expected_values, sess=None,
split_batch_by=split_batch_by)
self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
worker_device_pairs, expected_values, sess=None,
split_batch_by=split_batch_by)
class InputIteratorMultiWorkerTest(
multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase,
class DistributedIteratorMultiWorkerTest(
multi_worker_test_base.MultiWorkerTestBase, DistributedIteratorTestBase,
parameterized.TestCase):
def _cpu_devices(self):
@@ -206,77 +310,109 @@ class InputIteratorMultiWorkerTest(
@combinations.generate(combinations.combine(
mode=["graph"],
input_type=["input_fn", "dataset"]))
def testOneDevicePerWorker(self, input_type):
input_type=["input_fn", "dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"]))
def testOneDevicePerWorker(self, input_type, api_type, iteration_type):
worker_devices = self._cpu_devices()
with context.graph_mode(), self.cached_session() as sess:
dataset_fn = lambda _: dataset_ops.Dataset.range(4)
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(4)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(4)
if input_type == "dataset":
# Autosharded
expected_values = [[0, 1], [2, 3]]
else:
expected_values = [[0, 0], [1, 1], [2, 2], [3, 3]]
self._test_iterator(input_type, dataset_fn, worker_devices,
expected_values, sess)
self._test_input_iteration(input_type, api_type, iteration_type,
dataset_fn, worker_devices,
expected_values, sess)
@combinations.generate(combinations.combine(
mode=["graph"],
input_type=["input_fn", "dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
required_gpus=1))
def testTwoDevicesPerWorker(self, input_type):
def testTwoDevicesPerWorker(self, input_type, api_type, iteration_type):
worker_devices = self._cpu_and_one_gpu_devices()
with context.graph_mode(), self.cached_session() as sess:
dataset_fn = lambda _: dataset_ops.Dataset.range(4)
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(4)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(4)
if input_type == "dataset":
# Autosharded
expected_values = [[0, 2, 1, 3]]
else:
expected_values = [[0, 1, 0, 1], [2, 3, 2, 3]]
self._test_iterator(input_type, dataset_fn, worker_devices,
expected_values, sess)
self._test_input_iteration(input_type, api_type, iteration_type,
dataset_fn, worker_devices,
expected_values, sess)
@combinations.generate(combinations.combine(
mode=["graph"],
input_type=["input_fn", "dataset"]))
def testTupleDataset(self, input_type):
input_type=["input_fn", "dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"]))
def testTupleDataset(self, input_type, api_type, iteration_type):
worker_devices = self._cpu_devices()
with context.graph_mode(), self.cached_session() as sess:
def dataset_fn(ctx):
del ctx
dataset1 = dataset_ops.Dataset.range(4)
dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
if tf2.enabled():
dataset1 = dataset_ops.DatasetV2.range(4)
dataset2 = dataset_ops.DatasetV2.range(4).map(lambda x: x**2)
return dataset_ops.DatasetV2.zip((dataset1, dataset2))
else:
dataset1 = dataset_ops.Dataset.range(4)
dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
if input_type == "dataset":
# Autosharded
expected_values = [[(0, 0), (1, 1)], [(2, 4), (3, 9)]]
else:
expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)]
self._test_iterator(input_type, dataset_fn, worker_devices,
expected_values, sess)
self._test_input_iteration(input_type, api_type, iteration_type,
dataset_fn, worker_devices, expected_values,
sess)
@combinations.generate(
combinations.combine(
mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1))
def testUnevenDatasetBatches(self, input_type):
mode=["graph"],
input_type=["input_fn", "dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
required_gpus=1))
def testUnevenDatasetBatches(self, input_type, api_type, iteration_type):
worker_devices = self._cpu_and_one_gpu_devices()
with context.graph_mode(), self.cached_session() as sess:
dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(2)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
if input_type == "dataset":
# Autosharded
expected_values = [[[0, 1], [4, 5], [2, 3], [6, 7]], [[8], [], [], []]]
else:
expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
[[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]]
self._test_iterator(input_type, dataset_fn, worker_devices,
expected_values, sess,
enable_get_next_as_optional=True)
self._test_input_iteration(input_type, api_type, iteration_type,
dataset_fn, worker_devices, expected_values,
sess, enable_get_next_as_optional=True)
@combinations.generate(
combinations.combine(
mode=["graph"], input_type=["input_fn"], required_gpus=1))
def testDifferentDatasets(self, input_type):
mode=["graph"], input_type=["input_fn"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
required_gpus=1))
def testDifferentDatasets(self, input_type, api_type, iteration_type):
worker_devices = self._cpu_and_one_gpu_devices()
with context.graph_mode(), self.cached_session() as sess:
@@ -288,10 +424,9 @@ class InputIteratorMultiWorkerTest(
expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
[[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]]
self._test_iterator(input_type, dataset_fn, worker_devices,
expected_values, sess,
enable_get_next_as_optional=True)
self._test_input_iteration(input_type, api_type, iteration_type,
dataset_fn, worker_devices, expected_values,
sess, enable_get_next_as_optional=True)
if __name__ == "__main__":
test.main()

View File

@@ -597,6 +597,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return input_lib.InputFunctionIterator(
input_fn, self._input_workers, input_contexts)
def _experimental_distribute_dataset(self, dataset):
return input_lib.get_distributed_dataset(dataset, self._input_workers,
self._num_replicas_in_sync)
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, self._host_input_device, session)

View File

@@ -105,6 +105,11 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
del destinations
return tensor
def _experimental_distribute_dataset(self, dataset):
# Note that split_batch_by argument is not passed because it is always 1 in
# this strategy, and adding it adds unnecessary overhead to the dataset.
return input_lib.get_distributed_dataset(dataset, self._input_workers)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
initial_loop_values=None):

View File

@@ -282,6 +282,10 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate(colocate_with_variable, self)
def _experimental_distribute_dataset(self, dataset):
return input_lib.get_distributed_dataset(dataset, self._input_workers,
self._num_replicas_in_sync)
def _make_dataset_iterator(self, dataset):
return input_lib.DatasetIterator(dataset, self._input_workers,
self._num_replicas_in_sync)

View File

@@ -325,6 +325,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
numpy_input, numpy_dataset.SingleDevice(self._host_device),
session)
def _experimental_distribute_dataset(self, dataset):
return input_lib.get_distributed_dataset(dataset, self._input_workers,
self._num_replicas_in_sync)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
# TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
# a mechanism to infer the outputs of `fn`. Pending b/110550782.

View File

@@ -24,6 +24,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -24,6 +24,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -23,6 +23,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -24,6 +24,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -24,6 +24,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -24,6 +24,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -28,6 +28,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -23,6 +23,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -23,6 +23,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -22,6 +22,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -23,6 +23,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -23,6 +23,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -23,6 +23,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -23,6 +23,10 @@ tf_class {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"