Make the DistributionStrategy InputIterator get_next() API behave correctly on the last, partial batch when there are multiple devices/workers.

There are now three cases for get_next():
1. Every replica gets a full batch; the behavior is the same as before.
2. Some replicas get full batches, some get partial batches, and some get no data. get_next() returns values for all replicas; a replica that has run out of data receives a tensor whose batch dimension is 0.
3. If no replica has any data left, an OutOfRange error is raised.

PiperOrigin-RevId: 238491718
Parent: 0094a3fd25
Commit: 766a1e0037
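Before the diff, a minimal sketch of how the three cases surface to a caller. It is modeled on the testUnevenDatasetBatches test added below (one worker, two replica devices, Dataset.range(9).batch(2)); the device names are illustrative and the snippet assumes eager execution with an available GPU (in graph mode the OutOfRangeError surfaces when the fetched tensors are evaluated, as in the test).

# Sketch only: mirrors the new testUnevenDatasetBatches test below.
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import values
from tensorflow.python.framework import errors
from tensorflow.python.ops import control_flow_ops

devices = ["/device:GPU:0", "/device:CPU:0"]
device_map = values.ReplicaDeviceMap(devices)
input_workers = input_lib.InputWorkers(device_map, [("", devices)])
iterator = input_lib.DatasetIterator(
    dataset_ops.Dataset.range(9).batch(2), input_workers, None)
control_flow_ops.group(iterator.initialize())

# Step 1 yields [0, 1] and [2, 3]; step 2 yields [4, 5] and [6, 7]   (case 1).
# Step 3 yields [8] and a tensor with batch dimension 0              (case 2).
# Step 4 raises errors.OutOfRangeError                               (case 3).
try:
  while True:
    next_element = iterator.get_next()
    per_replica = [values.select_replica(r, next_element)
                   for r in range(len(devices))]
except errors.OutOfRangeError:
  pass

The point of case 2 is that a replica which has run out of data still takes part in the step, just with an empty (batch-dimension-0) tensor; only when no replica has data does the iterator raise.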
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -35,36 +35,54 @@ from tensorflow.python.util import nest
 class InputIteratorTestBase(test.TestCase):
 
-  def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
-                     expected_values, sess=None, split_batch_by=None):
-    devices = nest.flatten([ds for _, ds in worker_device_pairs])
+  def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
+                       devices, split_batch_by):
     device_map = values.ReplicaDeviceMap(devices)
     input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
 
     if input_type == "input_fn":
-      input_contexts = [
-          distribute_lib.InputContext() for _ in worker_device_pairs]
-      input_fn = lambda _: dataset_fn()
-      iterator = input_lib.InputFunctionIterator(
-          input_fn, input_workers, input_contexts)
+      input_contexts = []
+      for i in range(input_workers.num_workers):
+        input_contexts.append(
+            distribute_lib.InputContext(
+                num_input_pipelines=input_workers.num_workers,
+                input_pipeline_id=i,
+                num_replicas_in_sync=len(devices)))
+
+      iterator = input_lib.InputFunctionIterator(dataset_fn, input_workers,
+                                                 input_contexts)
     else:
       iterator = input_lib.DatasetIterator(
-          dataset_fn(), input_workers, split_batch_by)
+          dataset_fn(distribute_lib.InputContext()), input_workers,
+          split_batch_by)
+    return iterator
+
+  def _test_iterator(self,
+                     input_type,
+                     dataset_fn,
+                     worker_device_pairs,
+                     expected_values,
+                     sess=None,
+                     split_batch_by=None):
+    devices = nest.flatten([ds for _, ds in worker_device_pairs])
+    iterator = self._create_iterator(
+        input_type, dataset_fn, worker_device_pairs, devices, split_batch_by)
 
     evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
 
     evaluate(control_flow_ops.group(iterator.initialize()))
 
     for expected_value in expected_values:
       next_element = iterator.get_next()
       computed_value = evaluate(
           [values.select_replica(r, next_element) for r in range(len(devices))])
-      self.assertAllEqual(expected_value, computed_value)
+      self.assertEqual(len(expected_value), len(computed_value))
+      for i in range(len(expected_value)):
+        self.assertAllEqual(expected_value[i], computed_value[i])
 
     with self.assertRaises(errors.OutOfRangeError):
       next_element = iterator.get_next()
-      evaluate([values.select_replica(r, next_element)
-                for r in range(len(devices))])
+      evaluate(
+          [values.select_replica(r, next_element) for r in range(len(devices))])
 
     # After re-initializing the iterator, should be able to iterate again.
     evaluate(control_flow_ops.group(iterator.initialize()))
@@ -73,7 +91,9 @@ class InputIteratorTestBase(test.TestCase):
       next_element = iterator.get_next()
       computed_value = evaluate(
           [values.select_replica(r, next_element) for r in range(len(devices))])
-      self.assertAllEqual(expected_value, computed_value)
+      self.assertEqual(len(expected_value), len(computed_value))
+      for i in range(len(expected_value)):
+        self.assertAllEqual(expected_value[i], computed_value[i])
 
 
 class InputIteratorSingleWorkerTest(InputIteratorTestBase,
@@ -84,7 +104,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                                       input_type=["input_fn", "dataset"]))
   def testOneDeviceCPU(self, input_type):
     worker_device_pairs = [("", ["/device:CPU:0"])]
-    dataset_fn = lambda: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
 
     expected_values = [[i] for i in range(10)]
 
@@ -97,7 +117,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                                       required_gpus=1))
   def testTwoDevicesOneGPUOneCPU(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    dataset_fn = lambda: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
 
     expected_values = [[i, i+1] for i in range(0, 10, 2)]
 
@@ -110,7 +130,9 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                                       required_gpus=1))
   def testTupleDataset(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    def dataset_fn():
+
+    def dataset_fn(ctx):
+      del ctx
       dataset1 = dataset_ops.Dataset.range(10)
       dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
       return dataset_ops.Dataset.zip((dataset1, dataset2))
@@ -120,15 +142,17 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
     self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                         expected_values)
 
-  @combinations.generate(combinations.combine(
-      mode=["graph", "eager"],
-      input_type=["input_fn", "dataset"],
-      required_gpus=1))
+  @combinations.generate(
+      combinations.combine(
+          mode=["graph", "eager"],
+          input_type=["input_fn", "dataset"],
+          required_gpus=1))
   def testUnevenDatasetBatches(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    dataset_fn = lambda: dataset_ops.Dataset.range(11)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
 
-    expected_values = [[i, i+1] for i in range(0, 10, 2)]
+    # The last global batch only contains data for one replica.
+    expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
     self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                         expected_values)
 
@@ -140,7 +164,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
   def testBatchSplitting(self, input_type, split_batch_by):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
     batch_size = 10
-    dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
 
     updated_batch_size = (
         batch_size // split_batch_by if split_batch_by else batch_size)
@@ -182,7 +206,7 @@ class InputIteratorMultiWorkerTest(
   def testOneDevicePerWorker(self, input_type):
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda: dataset_ops.Dataset.range(4)
+      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           [[0, 0], [1, 1], [2, 2], [3, 3]], sess)
 
@@ -193,7 +217,7 @@ class InputIteratorMultiWorkerTest(
   def testTwoDevicesPerWorker(self, input_type):
     worker_devices = self._cpu_and_one_gpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda: dataset_ops.Dataset.range(4)
+      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           [[0, 1, 0, 1], [2, 3, 2, 3]], sess)
 
@@ -203,7 +227,9 @@ class InputIteratorMultiWorkerTest(
   def testTupleDataset(self, input_type):
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      def dataset_fn():
+
+      def dataset_fn(ctx):
+        del ctx
         dataset1 = dataset_ops.Dataset.range(4)
         dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
         return dataset_ops.Dataset.zip((dataset1, dataset2))
@@ -212,6 +238,36 @@ class InputIteratorMultiWorkerTest(
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           expected_values, sess)
 
+  @combinations.generate(
+      combinations.combine(
+          mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1))
+  def testUnevenDatasetBatches(self, input_type):
+    worker_devices = self._cpu_and_one_gpu_devices()
+    with context.graph_mode(), self.cached_session() as sess:
+      dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
+      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
+                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]]
+      self._test_iterator(input_type, dataset_fn, worker_devices,
+                          expected_values, sess)
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["graph"], input_type=["input_fn"], required_gpus=1))
+  def testDifferentDatasets(self, input_type):
+    worker_devices = self._cpu_and_one_gpu_devices()
+    with context.graph_mode(), self.cached_session() as sess:
+
+      def dataset_fn(ctx):
+        if ctx.input_pipeline_id == 0:
+          return dataset_ops.Dataset.range(8).batch(2)
+        else:
+          return dataset_ops.Dataset.range(9).batch(2)
+
+      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
+                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]]
+      self._test_iterator(input_type, dataset_fn, worker_devices,
+                          expected_values, sess)
+
 
 if __name__ == "__main__":
   test.main()
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -21,14 +21,21 @@ from __future__ import print_function
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.data.util import structure
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import input_ops
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util import nest
 
 
 class InputWorkers(object):
@@ -126,6 +133,7 @@ class InputIteratorImpl(InputIterator):
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
     replicas = []
+    worker_has_values = []
    for i, worker in enumerate(self._input_workers.worker_devices):
       if name is not None:
         d = tf_device.DeviceSpec.from_string(worker)
@@ -133,8 +141,61 @@ class InputIteratorImpl(InputIterator):
       else:
         new_name = None
       with ops.device(worker):
+        worker_has_value, next_element = (
+            self._iterators[i].get_next_as_list(new_name))
+        worker_has_values.append(worker_has_value)
         # Make `replicas` a flat list of values across all replicas.
-        replicas.extend(self._iterators[i].get_next_as_list(new_name))
+        replicas.append(next_element)
 
+    out_of_range_replicas = []
+
+    def out_of_range_fn(worker_index, device):
+      """This function will throw an OutOfRange error."""
+      # As this will be only called when there is no data left, so calling
+      # get_next() will trigger an OutOfRange error.
+      data = self._iterators[worker_index].get_next(device)
+      out_of_range_replicas.append(data)
+      return data
+
+    # `global_has_value` indicates whether there is data in this global batch.
+    # We do a all-reduce across all the workers in the multi-worker case.
+    # TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
+    if len(worker_has_values) > 1:
+      with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
+        # Place the tf.reduce_any op in device 0 to minimize communication
+        # cost.
+        # TODO(b/128545270): Investigate why placing it on worker 0 will cause
+        # the entire data to copy back from device to host.
+        global_has_value = math_ops.reduce_any(worker_has_values)
+    else:
+      global_has_value = worker_has_values[0]
+
+    results = []
+    for i, worker in enumerate(self._input_workers.worker_devices):
+      with ops.device(worker):
+        devices = self._input_workers.compute_devices_for_worker(i)
+        for j, device in enumerate(devices):
+          with ops.device(device):
+            # pylint: disable=undefined-loop-variable
+            # pylint: disable=cell-var-from-loop
+            # It is fine for the lambda to capture variables from the loop as
+            # the lambda is executed in the loop as well.
+            result = control_flow_ops.cond(global_has_value,
+                                           lambda: replicas[i][j],
+                                           lambda: out_of_range_fn(i, device))
+            # pylint: enable=cell-var-from-loop
+            # pylint: enable=undefined-loop-variable
+            results.append(result)
+    replicas = results
+
+    # Some dimensions in `replicas` will become unknown after we conditionally
+    # return the real tensors or the dummy tensors. We fix the input shapes by
+    # using the shapes from `out_of_range_replicas` because it is calling
+    # get_next() inside.
+    flattened_replicas = nest.flatten(replicas)
+    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
+      flattened_replicas[i].set_shape(replica_data.get_shape())
+    replicas = nest.pack_sequence_as(replicas, flattened_replicas)
+
     return values.regroup(self._input_workers.device_map, replicas)
 
@@ -265,6 +326,45 @@ class DatasetIterator(InputIteratorImpl):
     super(DatasetIterator, self).__init__(input_workers, iterators)
 
 
+def _dummy_tensor_fn(value_structure):
+  """A function to create dummy tensors from `value_structure`."""
+
+  def create_dummy_tensor(feature_shape, feature_type):
+    """Create a dummy tensor with possible batch dimensions set to 0."""
+
+    # Ideally we should set the batch dimension to 0, however as in
+    # DistributionStrategy we don't know the batch dimension, we try to
+    # guess it as much as possible. If the feature has unknown dimensions, we
+    # will set them to 0. If the feature shape is already static, we guess the
+    # first dimension as batch dimension and set it to 0.
+    dims = []
+    for dim in feature_shape.dims:
+      if dim.value is None:
+        dims.append(tensor_shape.Dimension(0))
+      else:
+        dims.append(dim)
+    if feature_shape.is_fully_defined() and dims:
+      dims[0] = tensor_shape.Dimension(0)
+
+    # Create the dummy tensor.
+    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
+    return dummy_tensor
+
+  result = []
+  # pylint: disable=protected-access
+  for feature_shape, feature_type in zip(value_structure._flat_shapes,
+                                         value_structure._flat_types):
+    result.append(create_dummy_tensor(feature_shape, feature_type))
+
+  if isinstance(value_structure, structure.NestedStructure):
+    result = nest.pack_sequence_as(value_structure._nested_structure, result)
+  else:
+    result = result[0]
+  # pylint: enable=protected-access
+
+  return result
+
+
 class _SingleWorkerDatasetIterator(object):
   """Iterator for a single `tf.data.Dataset`."""
 
@@ -290,12 +390,51 @@ class _SingleWorkerDatasetIterator(object):
     self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
         self._dataset, self._devices)
 
-  def get_next_as_list(self, name=None):
-    """Get next element from the underlying iterator."""
+  def get_next(self, device, name=None):
+    """Get next element for the given device."""
     del name
     with ops.device(self._worker):
-      data_list = self._iterator.get_next()
-      return data_list
+      return self._iterator.get_next(device)
+
+  def get_next_as_list(self, name=None):
+    """Get next element from underlying iterator.
+
+    If there is no data left, a list of dummy tensors with possible batch
+    dimensions set to 0 will be returned.
+
+    Args:
+      name: not used.
+
+    Returns:
+      A boolean tensor indicates whether there is any data in next element and
+      the real data as the next element or a list of dummy tensors if no data
+      left.
+    """
+    del name
+    with ops.device(self._worker):
+      data_list = self._iterator.get_next_as_optional()
+      result = []
+      for i, data in enumerate(data_list):
+        # Place the condition op in the same device as the data so the data
+        # doesn't need to be sent back to the worker.
+        with ops.device(self._devices[i]):
+          # As MultiDeviceIterator will fetch data in order, so we only need to
+          # check if the first replica has value to see whether there is data
+          # left for this single worker.
+          if i == 0:
+            worker_has_value = data.has_value()
+
+          # pylint: disable=unnecessary-lambda
+          # pylint: disable=cell-var-from-loop
+          real_data = control_flow_ops.cond(
+              data.has_value(),
+              lambda: data.get_value(),
+              lambda: _dummy_tensor_fn(data.value_structure))
+          result.append(real_data)
+          # pylint: enable=cell-var-from-loop
+          # pylint: enable=unnecessary-lambda
+
+      return worker_has_value, result
+
   def initialize(self):
     """Initialze underlying iterator.
@@ -334,12 +473,18 @@ class _SingleWorkerCallableIterator(object):
     self._worker = worker
     self._devices = devices
 
+  def get_next(self, device, name=None):
+    """Get next element for the given device from the callable."""
+    del device, name
+    with ops.device(self._worker):
+      return self._fn()
+
   def get_next_as_list(self, name=None):
     """Get next element from the callable."""
     del name
     with ops.device(self._worker):
       data_list = [self._fn() for _ in self._devices]
-      return data_list
+      return constant_op.constant(True), data_list
 
   def initialize(self):
     # TODO(petebu) Should this throw an exception instead?
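As an aside, the per-replica fallback that _dummy_tensor_fn and get_next_as_list implement above boils down to a small pattern: pick the replica's real value when it exists, otherwise a zero tensor whose batch dimension is 0. The sketch below is an illustrative distillation using the public TensorFlow API, not code from this change, and value_or_zero_batch is a hypothetical helper name.

# Illustrative distillation of the fallback above (hypothetical helper).
import tensorflow as tf

def value_or_zero_batch(has_value, real_value_fn, shape, dtype):
  """Return real_value_fn() if has_value, else a dummy with batch dimension 0."""
  # Unknown dimensions become 0; for a fully static shape, treat the first
  # dimension as the batch dimension and zero it out.
  dims = [0 if d is None else d for d in shape.as_list()]
  if shape.is_fully_defined() and dims:
    dims[0] = 0
  return tf.cond(has_value, real_value_fn, lambda: tf.zeros(dims, dtype))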