parent
18493fb2a3
commit
e18f3efa34
@ -35,54 +35,36 @@ from tensorflow.python.util import nest
|
||||
|
||||
class InputIteratorTestBase(test.TestCase):
|
||||
|
||||
def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
|
||||
devices, split_batch_by):
|
||||
def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
|
||||
expected_values, sess=None, split_batch_by=None):
|
||||
devices = nest.flatten([ds for _, ds in worker_device_pairs])
|
||||
device_map = values.ReplicaDeviceMap(devices)
|
||||
input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
|
||||
|
||||
if input_type == "input_fn":
|
||||
input_contexts = []
|
||||
for i in range(input_workers.num_workers):
|
||||
input_contexts.append(
|
||||
distribute_lib.InputContext(
|
||||
num_input_pipelines=input_workers.num_workers,
|
||||
input_pipeline_id=i,
|
||||
num_replicas_in_sync=len(devices)))
|
||||
|
||||
iterator = input_lib.InputFunctionIterator(dataset_fn, input_workers,
|
||||
input_contexts)
|
||||
input_contexts = [
|
||||
distribute_lib.InputContext() for _ in worker_device_pairs]
|
||||
input_fn = lambda _: dataset_fn()
|
||||
iterator = input_lib.InputFunctionIterator(
|
||||
input_fn, input_workers, input_contexts)
|
||||
else:
|
||||
iterator = input_lib.DatasetIterator(
|
||||
dataset_fn(distribute_lib.InputContext()), input_workers,
|
||||
split_batch_by)
|
||||
return iterator
|
||||
|
||||
def _test_iterator(self,
|
||||
input_type,
|
||||
dataset_fn,
|
||||
worker_device_pairs,
|
||||
expected_values,
|
||||
sess=None,
|
||||
split_batch_by=None):
|
||||
devices = nest.flatten([ds for _, ds in worker_device_pairs])
|
||||
iterator = self._create_iterator(
|
||||
input_type, dataset_fn, worker_device_pairs, devices, split_batch_by)
|
||||
dataset_fn(), input_workers, split_batch_by)
|
||||
|
||||
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
|
||||
|
||||
evaluate(control_flow_ops.group(iterator.initialize()))
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
computed_value = evaluate(
|
||||
[values.select_replica(r, next_element) for r in range(len(devices))])
|
||||
self.assertEqual(len(expected_value), len(computed_value))
|
||||
for i in range(len(expected_value)):
|
||||
self.assertAllEqual(expected_value[i], computed_value[i])
|
||||
self.assertAllEqual(expected_value, computed_value)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
next_element = iterator.get_next()
|
||||
evaluate(
|
||||
[values.select_replica(r, next_element) for r in range(len(devices))])
|
||||
evaluate([values.select_replica(r, next_element)
|
||||
for r in range(len(devices))])
|
||||
|
||||
# After re-initializing the iterator, should be able to iterate again.
|
||||
evaluate(control_flow_ops.group(iterator.initialize()))
|
||||
@ -91,9 +73,7 @@ class InputIteratorTestBase(test.TestCase):
|
||||
next_element = iterator.get_next()
|
||||
computed_value = evaluate(
|
||||
[values.select_replica(r, next_element) for r in range(len(devices))])
|
||||
self.assertEqual(len(expected_value), len(computed_value))
|
||||
for i in range(len(expected_value)):
|
||||
self.assertAllEqual(expected_value[i], computed_value[i])
|
||||
self.assertAllEqual(expected_value, computed_value)
|
||||
|
||||
|
||||
class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
||||
@ -104,7 +84,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
||||
input_type=["input_fn", "dataset"]))
|
||||
def testOneDeviceCPU(self, input_type):
|
||||
worker_device_pairs = [("", ["/device:CPU:0"])]
|
||||
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
|
||||
dataset_fn = lambda: dataset_ops.Dataset.range(10)
|
||||
|
||||
expected_values = [[i] for i in range(10)]
|
||||
|
||||
@ -117,7 +97,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
||||
required_gpus=1))
|
||||
def testTwoDevicesOneGPUOneCPU(self, input_type):
|
||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
|
||||
dataset_fn = lambda: dataset_ops.Dataset.range(10)
|
||||
|
||||
expected_values = [[i, i+1] for i in range(0, 10, 2)]
|
||||
|
||||
@ -130,9 +110,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
||||
required_gpus=1))
|
||||
def testTupleDataset(self, input_type):
|
||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||
|
||||
def dataset_fn(ctx):
|
||||
del ctx
|
||||
def dataset_fn():
|
||||
dataset1 = dataset_ops.Dataset.range(10)
|
||||
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
|
||||
return dataset_ops.Dataset.zip((dataset1, dataset2))
|
||||
@ -142,17 +120,15 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
||||
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
|
||||
expected_values)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["graph", "eager"],
|
||||
input_type=["input_fn", "dataset"],
|
||||
required_gpus=1))
|
||||
@combinations.generate(combinations.combine(
|
||||
mode=["graph", "eager"],
|
||||
input_type=["input_fn", "dataset"],
|
||||
required_gpus=1))
|
||||
def testUnevenDatasetBatches(self, input_type):
|
||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||
dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
|
||||
dataset_fn = lambda: dataset_ops.Dataset.range(11)
|
||||
|
||||
# The last global batch only contains data for one replica.
|
||||
expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
|
||||
expected_values = [[i, i+1] for i in range(0, 10, 2)]
|
||||
self._test_iterator(input_type, dataset_fn, worker_device_pairs,
|
||||
expected_values)
|
||||
|
||||
@ -164,7 +140,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
|
||||
def testBatchSplitting(self, input_type, split_batch_by):
|
||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||
batch_size = 10
|
||||
dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
|
||||
dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)
|
||||
|
||||
updated_batch_size = (
|
||||
batch_size // split_batch_by if split_batch_by else batch_size)
|
||||
@ -206,7 +182,7 @@ class InputIteratorMultiWorkerTest(
|
||||
def testOneDevicePerWorker(self, input_type):
|
||||
worker_devices = self._cpu_devices()
|
||||
with context.graph_mode(), self.cached_session() as sess:
|
||||
dataset_fn = lambda _: dataset_ops.Dataset.range(4)
|
||||
dataset_fn = lambda: dataset_ops.Dataset.range(4)
|
||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
||||
[[0, 0], [1, 1], [2, 2], [3, 3]], sess)
|
||||
|
||||
@ -217,7 +193,7 @@ class InputIteratorMultiWorkerTest(
|
||||
def testTwoDevicesPerWorker(self, input_type):
|
||||
worker_devices = self._cpu_and_one_gpu_devices()
|
||||
with context.graph_mode(), self.cached_session() as sess:
|
||||
dataset_fn = lambda _: dataset_ops.Dataset.range(4)
|
||||
dataset_fn = lambda: dataset_ops.Dataset.range(4)
|
||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
||||
[[0, 1, 0, 1], [2, 3, 2, 3]], sess)
|
||||
|
||||
@ -227,9 +203,7 @@ class InputIteratorMultiWorkerTest(
|
||||
def testTupleDataset(self, input_type):
|
||||
worker_devices = self._cpu_devices()
|
||||
with context.graph_mode(), self.cached_session() as sess:
|
||||
|
||||
def dataset_fn(ctx):
|
||||
del ctx
|
||||
def dataset_fn():
|
||||
dataset1 = dataset_ops.Dataset.range(4)
|
||||
dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
|
||||
return dataset_ops.Dataset.zip((dataset1, dataset2))
|
||||
@ -238,36 +212,6 @@ class InputIteratorMultiWorkerTest(
|
||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
||||
expected_values, sess)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1))
|
||||
def testUnevenDatasetBatches(self, input_type):
|
||||
worker_devices = self._cpu_and_one_gpu_devices()
|
||||
with context.graph_mode(), self.cached_session() as sess:
|
||||
dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
|
||||
expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
|
||||
[[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]]
|
||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
||||
expected_values, sess)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["graph"], input_type=["input_fn"], required_gpus=1))
|
||||
def testDifferentDatasets(self, input_type):
|
||||
worker_devices = self._cpu_and_one_gpu_devices()
|
||||
with context.graph_mode(), self.cached_session() as sess:
|
||||
|
||||
def dataset_fn(ctx):
|
||||
if ctx.input_pipeline_id == 0:
|
||||
return dataset_ops.Dataset.range(8).batch(2)
|
||||
else:
|
||||
return dataset_ops.Dataset.range(9).batch(2)
|
||||
|
||||
expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
|
||||
[[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]]
|
||||
self._test_iterator(input_type, dataset_fn, worker_devices,
|
||||
expected_values, sess)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -21,21 +21,14 @@ from __future__ import print_function
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import multi_device_iterator_ops
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import input_ops
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device as tf_device
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class InputWorkers(object):
|
||||
@ -133,7 +126,6 @@ class InputIteratorImpl(InputIterator):
|
||||
def get_next(self, name=None):
|
||||
"""Returns the next input from the iterator for all replicas."""
|
||||
replicas = []
|
||||
worker_has_values = []
|
||||
for i, worker in enumerate(self._input_workers.worker_devices):
|
||||
if name is not None:
|
||||
d = tf_device.DeviceSpec.from_string(worker)
|
||||
@ -141,61 +133,8 @@ class InputIteratorImpl(InputIterator):
|
||||
else:
|
||||
new_name = None
|
||||
with ops.device(worker):
|
||||
worker_has_value, next_element = (
|
||||
self._iterators[i].get_next_as_list(new_name))
|
||||
worker_has_values.append(worker_has_value)
|
||||
# Make `replicas` a flat list of values across all replicas.
|
||||
replicas.append(next_element)
|
||||
|
||||
out_of_range_replicas = []
|
||||
|
||||
def out_of_range_fn(worker_index, device):
|
||||
"""This function will throw an OutOfRange error."""
|
||||
# This is only called when there is no data left, so calling
|
||||
# get_next() will trigger an OutOfRange error.
|
||||
data = self._iterators[worker_index].get_next(device)
|
||||
out_of_range_replicas.append(data)
|
||||
return data
|
||||
|
||||
# `global_has_value` indicates whether there is data in this global batch.
|
||||
# We do an all-reduce across all the workers in the multi-worker case.
|
||||
# TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
|
||||
if len(worker_has_values) > 1:
|
||||
with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
|
||||
# Place the tf.reduce_any op in device 0 to minimize communication
|
||||
# cost.
|
||||
# TODO(b/128545270): Investigate why placing it on worker 0 will cause
|
||||
# the entire data to copy back from device to host.
|
||||
global_has_value = math_ops.reduce_any(worker_has_values)
|
||||
else:
|
||||
global_has_value = worker_has_values[0]
|
||||
|
||||
results = []
|
||||
for i, worker in enumerate(self._input_workers.worker_devices):
|
||||
with ops.device(worker):
|
||||
devices = self._input_workers.compute_devices_for_worker(i)
|
||||
for j, device in enumerate(devices):
|
||||
with ops.device(device):
|
||||
# pylint: disable=undefined-loop-variable
|
||||
# pylint: disable=cell-var-from-loop
|
||||
# It is fine for the lambda to capture variables from the loop as
|
||||
# the lambda is executed in the loop as well.
|
||||
result = control_flow_ops.cond(global_has_value,
|
||||
lambda: replicas[i][j],
|
||||
lambda: out_of_range_fn(i, device))
|
||||
# pylint: enable=cell-var-from-loop
|
||||
# pylint: enable=undefined-loop-variable
|
||||
results.append(result)
|
||||
replicas = results
|
||||
|
||||
# Some dimensions in `replicas` will become unknown after we conditionally
|
||||
# return the real tensors or the dummy tensors. We fix the input shapes by
|
||||
# using the shapes from `out_of_range_replicas` because it is calling
|
||||
# get_next() inside.
|
||||
flattened_replicas = nest.flatten(replicas)
|
||||
for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
|
||||
flattened_replicas[i].set_shape(replica_data.get_shape())
|
||||
replicas = nest.pack_sequence_as(replicas, flattened_replicas)
|
||||
replicas.extend(self._iterators[i].get_next_as_list(new_name))
|
||||
|
||||
return values.regroup(self._input_workers.device_map, replicas)
|
||||
|
||||
@ -326,45 +265,6 @@ class DatasetIterator(InputIteratorImpl):
|
||||
super(DatasetIterator, self).__init__(input_workers, iterators)
|
||||
|
||||
|
||||
def _dummy_tensor_fn(value_structure):
|
||||
"""A function to create dummy tensors from `value_structure`."""
|
||||
|
||||
def create_dummy_tensor(feature_shape, feature_type):
|
||||
"""Create a dummy tensor with possible batch dimensions set to 0."""
|
||||
|
||||
# Ideally we should set the batch dimension to 0, however as in
|
||||
# DistributionStrategy we don't know the batch dimension, we try to
|
||||
# guess it as much as possible. If the feature has unknown dimensions, we
|
||||
# will set them to 0. If the feature shape is already static, we guess the
|
||||
# first dimension as batch dimension and set it to 0.
|
||||
dims = []
|
||||
for dim in feature_shape.dims:
|
||||
if dim.value is None:
|
||||
dims.append(tensor_shape.Dimension(0))
|
||||
else:
|
||||
dims.append(dim)
|
||||
if feature_shape.is_fully_defined() and dims:
|
||||
dims[0] = tensor_shape.Dimension(0)
|
||||
|
||||
# Create the dummy tensor.
|
||||
dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
|
||||
return dummy_tensor
|
||||
|
||||
result = []
|
||||
# pylint: disable=protected-access
|
||||
for feature_shape, feature_type in zip(value_structure._flat_shapes,
|
||||
value_structure._flat_types):
|
||||
result.append(create_dummy_tensor(feature_shape, feature_type))
|
||||
|
||||
if isinstance(value_structure, structure.NestedStructure):
|
||||
result = nest.pack_sequence_as(value_structure._nested_structure, result)
|
||||
else:
|
||||
result = result[0]
|
||||
# pylint: enable=protected-access
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class _SingleWorkerDatasetIterator(object):
|
||||
"""Iterator for a single `tf.data.Dataset`."""
|
||||
|
||||
@ -390,51 +290,12 @@ class _SingleWorkerDatasetIterator(object):
|
||||
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
self._dataset, self._devices)
|
||||
|
||||
def get_next(self, device, name=None):
|
||||
"""Get next element for the given device."""
|
||||
del name
|
||||
with ops.device(self._worker):
|
||||
return self._iterator.get_next(device)
|
||||
|
||||
def get_next_as_list(self, name=None):
|
||||
"""Get next element from underlying iterator.
|
||||
|
||||
If there is no data left, a list of dummy tensors with possible batch
|
||||
dimensions set to 0 will be returned.
|
||||
|
||||
Args:
|
||||
name: not used.
|
||||
|
||||
Returns:
|
||||
A boolean tensor indicating whether there is any data in the next element, and
|
||||
the real data as the next element or a list of dummy tensors if no data
|
||||
left.
|
||||
"""
|
||||
"""Get next element from the underlying iterator."""
|
||||
del name
|
||||
with ops.device(self._worker):
|
||||
data_list = self._iterator.get_next_as_optional()
|
||||
result = []
|
||||
for i, data in enumerate(data_list):
|
||||
# Place the condition op in the same device as the data so the data
|
||||
# doesn't need to be sent back to the worker.
|
||||
with ops.device(self._devices[i]):
|
||||
# MultiDeviceIterator fetches data in order, so we only need to
|
||||
# check if the first replica has value to see whether there is data
|
||||
# left for this single worker.
|
||||
if i == 0:
|
||||
worker_has_value = data.has_value()
|
||||
|
||||
# pylint: disable=unnecessary-lambda
|
||||
# pylint: disable=cell-var-from-loop
|
||||
real_data = control_flow_ops.cond(
|
||||
data.has_value(),
|
||||
lambda: data.get_value(),
|
||||
lambda: _dummy_tensor_fn(data.value_structure))
|
||||
result.append(real_data)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
# pylint: enable=unnecessary-lambda
|
||||
|
||||
return worker_has_value, result
|
||||
data_list = self._iterator.get_next()
|
||||
return data_list
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize underlying iterator.
|
||||
@ -473,18 +334,12 @@ class _SingleWorkerCallableIterator(object):
|
||||
self._worker = worker
|
||||
self._devices = devices
|
||||
|
||||
def get_next(self, device, name=None):
|
||||
"""Get next element for the given device from the callable."""
|
||||
del device, name
|
||||
with ops.device(self._worker):
|
||||
return self._fn()
|
||||
|
||||
def get_next_as_list(self, name=None):
|
||||
"""Get next element from the callable."""
|
||||
del name
|
||||
with ops.device(self._worker):
|
||||
data_list = [self._fn() for _ in self._devices]
|
||||
return constant_op.constant(True), data_list
|
||||
return data_list
|
||||
|
||||
def initialize(self):
|
||||
# TODO(petebu) Should this throw an exception instead?
|
||||
|
Loading…
Reference in New Issue
Block a user