parent
18493fb2a3
commit
e18f3efa34
@ -35,54 +35,36 @@ from tensorflow.python.util import nest
|
|||||||
|
|
||||||
class InputIteratorTestBase(test.TestCase):
|
class InputIteratorTestBase(test.TestCase):
|
||||||
|
|
||||||
def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
|
def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
|
||||||
devices, split_batch_by):
|
expected_values, sess=None, split_batch_by=None):
|
||||||
|
devices = nest.flatten([ds for _, ds in worker_device_pairs])
|
||||||
device_map = values.ReplicaDeviceMap(devices)
|
device_map = values.ReplicaDeviceMap(devices)
|
||||||
input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
|
input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
|
||||||
|
|
||||||
if input_type == "input_fn":
|
if input_type == "input_fn":
|
||||||
input_contexts = []
|
input_contexts = [
|
||||||
for i in range(input_workers.num_workers):
|
distribute_lib.InputContext() for _ in worker_device_pairs]
|
||||||
input_contexts.append(
|
input_fn = lambda _: dataset_fn()
|
||||||
distribute_lib.InputContext(
|
iterator = input_lib.InputFunctionIterator(
|
||||||
num_input_pipelines=input_workers.num_workers,
|
input_fn, input_workers, input_contexts)
|
||||||
input_pipeline_id=i,
|
|
||||||
num_replicas_in_sync=len(devices)))
|
|
||||||
|
|
||||||
iterator = input_lib.InputFunctionIterator(dataset_fn, input_workers,
|
|
||||||
input_contexts)
|
|
||||||
else:
|
else:
|
||||||
iterator = input_lib.DatasetIterator(
|
iterator = input_lib.DatasetIterator(
|
||||||
dataset_fn(distribute_lib.InputContext()), input_workers,
|
dataset_fn(), input_workers, split_batch_by)
|
||||||
split_batch_by)
|
|
||||||
return iterator
|
|
||||||
|
|
||||||
def _test_iterator(self,
|
|
||||||
input_type,
|
|
||||||
dataset_fn,
|
|
||||||
worker_device_pairs,
|
|
||||||
expected_values,
|
|
||||||
sess=None,
|
|
||||||
split_batch_by=None):
|
|
||||||
devices = nest.flatten([ds for _, ds in worker_device_pairs])
|
|
||||||
iterator = self._create_iterator(
|
|
||||||
input_type, dataset_fn, worker_device_pairs, devices, split_batch_by)
|
|
||||||
|
|
||||||
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
|
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
|
||||||
|
|
||||||
evaluate(control_flow_ops.group(iterator.initialize()))
|
evaluate(control_flow_ops.group(iterator.initialize()))
|
||||||
|
|
||||||
for expected_value in expected_values:
|
for expected_value in expected_values:
|
||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
computed_value = evaluate(
|
computed_value = evaluate(
|
||||||
[values.select_replica(r, next_element) for r in range(len(devices))])
|
[values.select_replica(r, next_element) for r in range(len(devices))])
|
||||||
self.assertEqual(len(expected_value), len(computed_value))
|
self.assertAllEqual(expected_value, computed_value)
|
||||||
for i in range(len(expected_value)):
|
|
||||||
self.assertAllEqual(expected_value[i], computed_value[i])
|
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
evaluate(
|
evaluate([values.select_replica(r, next_element)
|
||||||
[values.select_replica(r, next_element) for r in range(len(devices))])
|
for r in range(len(devices))])
|
||||||
|
|
||||||
# After re-initializing the iterator, should be able to iterate again.
|
# After re-initializing the iterator, should be able to iterate again.
|
||||||
evaluate(control_flow_ops.group(iterator.initialize()))
|
evaluate(control_flow_ops.group(iterator.initialize()))
|
||||||
@ -91,9 +73,7 @@ class InputIteratorTestBase(test.TestCase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
computed_value = evaluate(
|
computed_value = evaluate(
|
||||||
[values.select_replica(r, next_element) for r in range(len(devices))])
|
[values.select_replica(r, next_element) for r in range(len(devices))])
|
||||||
self.assertEqual(len(expected_value), len(computed_value))
|
self.assertAllEqual(expected_value, computed_value)
|
||||||
for i in range(len(expected_value)):
|
|
||||||
self.assertAllEqual(expected_value[i], computed_value[i])
|
|
||||||
|
|
||||||
|
|
||||||
class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
||||||
@ -104,7 +84,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
|||||||
input_type=["input_fn", "dataset"]))
|
input_type=["input_fn", "dataset"]))
|
||||||
def testOneDeviceCPU(self, input_type):
|
def testOneDeviceCPU(self, input_type):
|
||||||
worker_device_pairs = [("", ["/device:CPU:0"])]
|
worker_device_pairs = [("", ["/device:CPU:0"])]
|
||||||
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
|
dataset_fn = lambda: dataset_ops.Dataset.range(10)
|
||||||
|
|
||||||
expected_values = [[i] for i in range(10)]
|
expected_values = [[i] for i in range(10)]
|
||||||
|
|
||||||
@ -117,7 +97,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
|||||||
required_gpus=1))
|
required_gpus=1))
|
||||||
def testTwoDevicesOneGPUOneCPU(self, input_type):
|
def testTwoDevicesOneGPUOneCPU(self, input_type):
|
||||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||||
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
|
dataset_fn = lambda: dataset_ops.Dataset.range(10)
|
||||||
|
|
||||||
expected_values = [[i, i+1] for i in range(0, 10, 2)]
|
expected_values = [[i, i+1] for i in range(0, 10, 2)]
|
||||||
|
|
||||||
@ -130,9 +110,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
|||||||
required_gpus=1))
|
required_gpus=1))
|
||||||
def testTupleDataset(self, input_type):
|
def testTupleDataset(self, input_type):
|
||||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||||
|
def dataset_fn():
|
||||||
def dataset_fn(ctx):
|
|
||||||
del ctx
|
|
||||||
dataset1 = dataset_ops.Dataset.range(10)
|
dataset1 = dataset_ops.Dataset.range(10)
|
||||||
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
|
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
|
||||||
return dataset_ops.Dataset.zip((dataset1, dataset2))
|
return dataset_ops.Dataset.zip((dataset1, dataset2))
|
||||||
@ -142,17 +120,15 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
|||||||
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
|
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
|
||||||
expected_values)
|
expected_values)
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(combinations.combine(
|
||||||
combinations.combine(
|
mode=["graph", "eager"],
|
||||||
mode=["graph", "eager"],
|
input_type=["input_fn", "dataset"],
|
||||||
input_type=["input_fn", "dataset"],
|
required_gpus=1))
|
||||||
required_gpus=1))
|
|
||||||
def testUnevenDatasetBatches(self, input_type):
|
def testUnevenDatasetBatches(self, input_type):
|
||||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||||
dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
|
dataset_fn = lambda: dataset_ops.Dataset.range(11)
|
||||||
|
|
||||||
# The last global batch only contains data for one replica.
|
expected_values = [[i, i+1] for i in range(0, 10, 2)]
|
||||||
expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
|
|
||||||
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
|
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
|
||||||
expected_values)
|
expected_values)
|
||||||
|
|
||||||
@ -164,7 +140,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
|||||||
def testBatchSplitting(self, input_type, split_batch_by):
|
def testBatchSplitting(self, input_type, split_batch_by):
|
||||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||||
batch_size = 10
|
batch_size = 10
|
||||||
dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
|
dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)
|
||||||
|
|
||||||
updated_batch_size = (
|
updated_batch_size = (
|
||||||
batch_size // split_batch_by if split_batch_by else batch_size)
|
batch_size // split_batch_by if split_batch_by else batch_size)
|
||||||
@ -206,7 +182,7 @@ class InputIteratorMultiWorkerTest(
|
|||||||
def testOneDevicePerWorker(self, input_type):
|
def testOneDevicePerWorker(self, input_type):
|
||||||
worker_devices = self._cpu_devices()
|
worker_devices = self._cpu_devices()
|
||||||
with context.graph_mode(), self.cached_session() as sess:
|
with context.graph_mode(), self.cached_session() as sess:
|
||||||
dataset_fn = lambda _: dataset_ops.Dataset.range(4)
|
dataset_fn = lambda: dataset_ops.Dataset.range(4)
|
||||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
self._test_iterator(input_type, dataset_fn, worker_devices,
|
||||||
[[0, 0], [1, 1], [2, 2], [3, 3]], sess)
|
[[0, 0], [1, 1], [2, 2], [3, 3]], sess)
|
||||||
|
|
||||||
@ -217,7 +193,7 @@ class InputIteratorMultiWorkerTest(
|
|||||||
def testTwoDevicesPerWorker(self, input_type):
|
def testTwoDevicesPerWorker(self, input_type):
|
||||||
worker_devices = self._cpu_and_one_gpu_devices()
|
worker_devices = self._cpu_and_one_gpu_devices()
|
||||||
with context.graph_mode(), self.cached_session() as sess:
|
with context.graph_mode(), self.cached_session() as sess:
|
||||||
dataset_fn = lambda _: dataset_ops.Dataset.range(4)
|
dataset_fn = lambda: dataset_ops.Dataset.range(4)
|
||||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
self._test_iterator(input_type, dataset_fn, worker_devices,
|
||||||
[[0, 1, 0, 1], [2, 3, 2, 3]], sess)
|
[[0, 1, 0, 1], [2, 3, 2, 3]], sess)
|
||||||
|
|
||||||
@ -227,9 +203,7 @@ class InputIteratorMultiWorkerTest(
|
|||||||
def testTupleDataset(self, input_type):
|
def testTupleDataset(self, input_type):
|
||||||
worker_devices = self._cpu_devices()
|
worker_devices = self._cpu_devices()
|
||||||
with context.graph_mode(), self.cached_session() as sess:
|
with context.graph_mode(), self.cached_session() as sess:
|
||||||
|
def dataset_fn():
|
||||||
def dataset_fn(ctx):
|
|
||||||
del ctx
|
|
||||||
dataset1 = dataset_ops.Dataset.range(4)
|
dataset1 = dataset_ops.Dataset.range(4)
|
||||||
dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
|
dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
|
||||||
return dataset_ops.Dataset.zip((dataset1, dataset2))
|
return dataset_ops.Dataset.zip((dataset1, dataset2))
|
||||||
@ -238,36 +212,6 @@ class InputIteratorMultiWorkerTest(
|
|||||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
self._test_iterator(input_type, dataset_fn, worker_devices,
|
||||||
expected_values, sess)
|
expected_values, sess)
|
||||||
|
|
||||||
@combinations.generate(
|
|
||||||
combinations.combine(
|
|
||||||
mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1))
|
|
||||||
def testUnevenDatasetBatches(self, input_type):
|
|
||||||
worker_devices = self._cpu_and_one_gpu_devices()
|
|
||||||
with context.graph_mode(), self.cached_session() as sess:
|
|
||||||
dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
|
|
||||||
expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
|
|
||||||
[[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]]
|
|
||||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
|
||||||
expected_values, sess)
|
|
||||||
|
|
||||||
@combinations.generate(
|
|
||||||
combinations.combine(
|
|
||||||
mode=["graph"], input_type=["input_fn"], required_gpus=1))
|
|
||||||
def testDifferentDatasets(self, input_type):
|
|
||||||
worker_devices = self._cpu_and_one_gpu_devices()
|
|
||||||
with context.graph_mode(), self.cached_session() as sess:
|
|
||||||
|
|
||||||
def dataset_fn(ctx):
|
|
||||||
if ctx.input_pipeline_id == 0:
|
|
||||||
return dataset_ops.Dataset.range(8).batch(2)
|
|
||||||
else:
|
|
||||||
return dataset_ops.Dataset.range(9).batch(2)
|
|
||||||
|
|
||||||
expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
|
|
||||||
[[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]]
|
|
||||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
|
||||||
expected_values, sess)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -21,21 +21,14 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.data.experimental.ops import batching
|
from tensorflow.python.data.experimental.ops import batching
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.ops import multi_device_iterator_ops
|
from tensorflow.python.data.ops import multi_device_iterator_ops
|
||||||
from tensorflow.python.data.util import structure
|
|
||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.distribute import input_ops
|
from tensorflow.python.distribute import input_ops
|
||||||
from tensorflow.python.distribute import values
|
from tensorflow.python.distribute import values
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
|
||||||
from tensorflow.python.framework import device as tf_device
|
from tensorflow.python.framework import device as tf_device
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import control_flow_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
|
||||||
from tensorflow.python.util import nest
|
|
||||||
|
|
||||||
|
|
||||||
class InputWorkers(object):
|
class InputWorkers(object):
|
||||||
@ -133,7 +126,6 @@ class InputIteratorImpl(InputIterator):
|
|||||||
def get_next(self, name=None):
|
def get_next(self, name=None):
|
||||||
"""Returns the next input from the iterator for all replicas."""
|
"""Returns the next input from the iterator for all replicas."""
|
||||||
replicas = []
|
replicas = []
|
||||||
worker_has_values = []
|
|
||||||
for i, worker in enumerate(self._input_workers.worker_devices):
|
for i, worker in enumerate(self._input_workers.worker_devices):
|
||||||
if name is not None:
|
if name is not None:
|
||||||
d = tf_device.DeviceSpec.from_string(worker)
|
d = tf_device.DeviceSpec.from_string(worker)
|
||||||
@ -141,61 +133,8 @@ class InputIteratorImpl(InputIterator):
|
|||||||
else:
|
else:
|
||||||
new_name = None
|
new_name = None
|
||||||
with ops.device(worker):
|
with ops.device(worker):
|
||||||
worker_has_value, next_element = (
|
|
||||||
self._iterators[i].get_next_as_list(new_name))
|
|
||||||
worker_has_values.append(worker_has_value)
|
|
||||||
# Make `replicas` a flat list of values across all replicas.
|
# Make `replicas` a flat list of values across all replicas.
|
||||||
replicas.append(next_element)
|
replicas.extend(self._iterators[i].get_next_as_list(new_name))
|
||||||
|
|
||||||
out_of_range_replicas = []
|
|
||||||
|
|
||||||
def out_of_range_fn(worker_index, device):
|
|
||||||
"""This function will throw an OutOfRange error."""
|
|
||||||
# This will only be called when there is no data left, so calling
|
|
||||||
# get_next() will trigger an OutOfRange error.
|
|
||||||
data = self._iterators[worker_index].get_next(device)
|
|
||||||
out_of_range_replicas.append(data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
# `global_has_value` indicates whether there is data in this global batch.
|
|
||||||
# We do an all-reduce across all the workers in the multi-worker case.
|
|
||||||
# TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
|
|
||||||
if len(worker_has_values) > 1:
|
|
||||||
with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
|
|
||||||
# Place the tf.reduce_any op in device 0 to minimize communication
|
|
||||||
# cost.
|
|
||||||
# TODO(b/128545270): Investigate why placing it on worker 0 will cause
|
|
||||||
# the entire data to copy back from device to host.
|
|
||||||
global_has_value = math_ops.reduce_any(worker_has_values)
|
|
||||||
else:
|
|
||||||
global_has_value = worker_has_values[0]
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for i, worker in enumerate(self._input_workers.worker_devices):
|
|
||||||
with ops.device(worker):
|
|
||||||
devices = self._input_workers.compute_devices_for_worker(i)
|
|
||||||
for j, device in enumerate(devices):
|
|
||||||
with ops.device(device):
|
|
||||||
# pylint: disable=undefined-loop-variable
|
|
||||||
# pylint: disable=cell-var-from-loop
|
|
||||||
# It is fine for the lambda to capture variables from the loop as
|
|
||||||
# the lambda is executed in the loop as well.
|
|
||||||
result = control_flow_ops.cond(global_has_value,
|
|
||||||
lambda: replicas[i][j],
|
|
||||||
lambda: out_of_range_fn(i, device))
|
|
||||||
# pylint: enable=cell-var-from-loop
|
|
||||||
# pylint: enable=undefined-loop-variable
|
|
||||||
results.append(result)
|
|
||||||
replicas = results
|
|
||||||
|
|
||||||
# Some dimensions in `replicas` will become unknown after we conditionally
|
|
||||||
# return the real tensors or the dummy tensors. We fix the input shapes by
|
|
||||||
# using the shapes from `out_of_range_replicas` because it is calling
|
|
||||||
# get_next() inside.
|
|
||||||
flattened_replicas = nest.flatten(replicas)
|
|
||||||
for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
|
|
||||||
flattened_replicas[i].set_shape(replica_data.get_shape())
|
|
||||||
replicas = nest.pack_sequence_as(replicas, flattened_replicas)
|
|
||||||
|
|
||||||
return values.regroup(self._input_workers.device_map, replicas)
|
return values.regroup(self._input_workers.device_map, replicas)
|
||||||
|
|
||||||
@ -326,45 +265,6 @@ class DatasetIterator(InputIteratorImpl):
|
|||||||
super(DatasetIterator, self).__init__(input_workers, iterators)
|
super(DatasetIterator, self).__init__(input_workers, iterators)
|
||||||
|
|
||||||
|
|
||||||
def _dummy_tensor_fn(value_structure):
|
|
||||||
"""A function to create dummy tensors from `value_structure`."""
|
|
||||||
|
|
||||||
def create_dummy_tensor(feature_shape, feature_type):
|
|
||||||
"""Create a dummy tensor with possible batch dimensions set to 0."""
|
|
||||||
|
|
||||||
# Ideally we should set the batch dimension to 0, however as in
|
|
||||||
# DistributionStrategy we don't know the batch dimension, we try to
|
|
||||||
# guess it as much as possible. If the feature has unknown dimensions, we
|
|
||||||
# will set them to 0. If the feature shape is already static, we guess the
|
|
||||||
# first dimension as batch dimension and set it to 0.
|
|
||||||
dims = []
|
|
||||||
for dim in feature_shape.dims:
|
|
||||||
if dim.value is None:
|
|
||||||
dims.append(tensor_shape.Dimension(0))
|
|
||||||
else:
|
|
||||||
dims.append(dim)
|
|
||||||
if feature_shape.is_fully_defined() and dims:
|
|
||||||
dims[0] = tensor_shape.Dimension(0)
|
|
||||||
|
|
||||||
# Create the dummy tensor.
|
|
||||||
dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
|
|
||||||
return dummy_tensor
|
|
||||||
|
|
||||||
result = []
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
for feature_shape, feature_type in zip(value_structure._flat_shapes,
|
|
||||||
value_structure._flat_types):
|
|
||||||
result.append(create_dummy_tensor(feature_shape, feature_type))
|
|
||||||
|
|
||||||
if isinstance(value_structure, structure.NestedStructure):
|
|
||||||
result = nest.pack_sequence_as(value_structure._nested_structure, result)
|
|
||||||
else:
|
|
||||||
result = result[0]
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class _SingleWorkerDatasetIterator(object):
|
class _SingleWorkerDatasetIterator(object):
|
||||||
"""Iterator for a single `tf.data.Dataset`."""
|
"""Iterator for a single `tf.data.Dataset`."""
|
||||||
|
|
||||||
@ -390,51 +290,12 @@ class _SingleWorkerDatasetIterator(object):
|
|||||||
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||||
self._dataset, self._devices)
|
self._dataset, self._devices)
|
||||||
|
|
||||||
def get_next(self, device, name=None):
|
|
||||||
"""Get next element for the given device."""
|
|
||||||
del name
|
|
||||||
with ops.device(self._worker):
|
|
||||||
return self._iterator.get_next(device)
|
|
||||||
|
|
||||||
def get_next_as_list(self, name=None):
|
def get_next_as_list(self, name=None):
|
||||||
"""Get next element from underlying iterator.
|
"""Get next element from the underlying iterator."""
|
||||||
|
|
||||||
If there is no data left, a list of dummy tensors with possible batch
|
|
||||||
dimensions set to 0 will be returned.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: not used.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A boolean tensor indicating whether there is any data in the next element, and
|
|
||||||
the real data as the next element or a list of dummy tensors if no data
|
|
||||||
left.
|
|
||||||
"""
|
|
||||||
del name
|
del name
|
||||||
with ops.device(self._worker):
|
with ops.device(self._worker):
|
||||||
data_list = self._iterator.get_next_as_optional()
|
data_list = self._iterator.get_next()
|
||||||
result = []
|
return data_list
|
||||||
for i, data in enumerate(data_list):
|
|
||||||
# Place the condition op in the same device as the data so the data
|
|
||||||
# doesn't need to be sent back to the worker.
|
|
||||||
with ops.device(self._devices[i]):
|
|
||||||
# MultiDeviceIterator fetches data in order, so we only need to
|
|
||||||
# check if the first replica has value to see whether there is data
|
|
||||||
# left for this single worker.
|
|
||||||
if i == 0:
|
|
||||||
worker_has_value = data.has_value()
|
|
||||||
|
|
||||||
# pylint: disable=unnecessary-lambda
|
|
||||||
# pylint: disable=cell-var-from-loop
|
|
||||||
real_data = control_flow_ops.cond(
|
|
||||||
data.has_value(),
|
|
||||||
lambda: data.get_value(),
|
|
||||||
lambda: _dummy_tensor_fn(data.value_structure))
|
|
||||||
result.append(real_data)
|
|
||||||
# pylint: enable=cell-var-from-loop
|
|
||||||
# pylint: enable=unnecessary-lambda
|
|
||||||
|
|
||||||
return worker_has_value, result
|
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
"""Initialize underlying iterator.
|
"""Initialize underlying iterator.
|
||||||
@ -473,18 +334,12 @@ class _SingleWorkerCallableIterator(object):
|
|||||||
self._worker = worker
|
self._worker = worker
|
||||||
self._devices = devices
|
self._devices = devices
|
||||||
|
|
||||||
def get_next(self, device, name=None):
|
|
||||||
"""Get next element for the given device from the callable."""
|
|
||||||
del device, name
|
|
||||||
with ops.device(self._worker):
|
|
||||||
return self._fn()
|
|
||||||
|
|
||||||
def get_next_as_list(self, name=None):
|
def get_next_as_list(self, name=None):
|
||||||
"""Get next element from the callable."""
|
"""Get next element from the callable."""
|
||||||
del name
|
del name
|
||||||
with ops.device(self._worker):
|
with ops.device(self._worker):
|
||||||
data_list = [self._fn() for _ in self._devices]
|
data_list = [self._fn() for _ in self._devices]
|
||||||
return constant_op.constant(True), data_list
|
return data_list
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
# TODO(petebu) Should this throw an exception instead?
|
# TODO(petebu) Should this throw an exception instead?
|
||||||
|
Loading…
Reference in New Issue
Block a user