diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py
index 204f52b034f..80a1c7bae8f 100644
--- a/tensorflow/contrib/distribute/python/input_lib_test.py
+++ b/tensorflow/contrib/distribute/python/input_lib_test.py
@@ -35,36 +35,54 @@ from tensorflow.python.util import nest
 
 class InputIteratorTestBase(test.TestCase):
 
-  def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
-                     expected_values, sess=None, split_batch_by=None):
-    devices = nest.flatten([ds for _, ds in worker_device_pairs])
+  def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
+                       devices, split_batch_by):
     device_map = values.ReplicaDeviceMap(devices)
     input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
 
     if input_type == "input_fn":
-      input_contexts = [
-          distribute_lib.InputContext() for _ in worker_device_pairs]
-      input_fn = lambda _: dataset_fn()
-      iterator = input_lib.InputFunctionIterator(
-          input_fn, input_workers, input_contexts)
+      input_contexts = []
+      for i in range(input_workers.num_workers):
+        input_contexts.append(
+            distribute_lib.InputContext(
+                num_input_pipelines=input_workers.num_workers,
+                input_pipeline_id=i,
+                num_replicas_in_sync=len(devices)))
+
+      iterator = input_lib.InputFunctionIterator(dataset_fn, input_workers,
+                                                 input_contexts)
     else:
       iterator = input_lib.DatasetIterator(
-          dataset_fn(), input_workers, split_batch_by)
+          dataset_fn(distribute_lib.InputContext()), input_workers,
+          split_batch_by)
+    return iterator
+
+  def _test_iterator(self,
+                     input_type,
+                     dataset_fn,
+                     worker_device_pairs,
+                     expected_values,
+                     sess=None,
+                     split_batch_by=None):
+    devices = nest.flatten([ds for _, ds in worker_device_pairs])
+    iterator = self._create_iterator(
+        input_type, dataset_fn, worker_device_pairs, devices, split_batch_by)
 
     evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
-
     evaluate(control_flow_ops.group(iterator.initialize()))
 
     for expected_value in expected_values:
       next_element = iterator.get_next()
       computed_value = evaluate(
           [values.select_replica(r, next_element) for r in range(len(devices))])
-      self.assertAllEqual(expected_value, computed_value)
+      self.assertEqual(len(expected_value), len(computed_value))
+      for i in range(len(expected_value)):
+        self.assertAllEqual(expected_value[i], computed_value[i])
 
     with self.assertRaises(errors.OutOfRangeError):
       next_element = iterator.get_next()
-      evaluate([values.select_replica(r, next_element)
-                for r in range(len(devices))])
+      evaluate(
+          [values.select_replica(r, next_element) for r in range(len(devices))])
 
     # After re-initializing the iterator, should be able to iterate again.
     evaluate(control_flow_ops.group(iterator.initialize()))
@@ -73,7 +91,9 @@ class InputIteratorTestBase(test.TestCase):
       next_element = iterator.get_next()
       computed_value = evaluate(
           [values.select_replica(r, next_element) for r in range(len(devices))])
-      self.assertAllEqual(expected_value, computed_value)
+      self.assertEqual(len(expected_value), len(computed_value))
+      for i in range(len(expected_value)):
+        self.assertAllEqual(expected_value[i], computed_value[i])
 
 
 class InputIteratorSingleWorkerTest(InputIteratorTestBase,
@@ -84,7 +104,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                                     input_type=["input_fn", "dataset"]))
   def testOneDeviceCPU(self, input_type):
     worker_device_pairs = [("", ["/device:CPU:0"])]
-    dataset_fn = lambda: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
 
     expected_values = [[i] for i in range(10)]
 
@@ -97,7 +117,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
       required_gpus=1))
   def testTwoDevicesOneGPUOneCPU(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    dataset_fn = lambda: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
 
     expected_values = [[i, i+1] for i in range(0, 10, 2)]
 
@@ -110,7 +130,9 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
       required_gpus=1))
   def testTupleDataset(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    def dataset_fn():
+
+    def dataset_fn(ctx):
+      del ctx
       dataset1 = dataset_ops.Dataset.range(10)
       dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
       return dataset_ops.Dataset.zip((dataset1, dataset2))
@@ -120,15 +142,17 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
     self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                         expected_values)
 
-  @combinations.generate(combinations.combine(
-      mode=["graph", "eager"],
-      input_type=["input_fn", "dataset"],
-      required_gpus=1))
+  @combinations.generate(
+      combinations.combine(
+          mode=["graph", "eager"],
+          input_type=["input_fn", "dataset"],
+          required_gpus=1))
   def testUnevenDatasetBatches(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    dataset_fn = lambda: dataset_ops.Dataset.range(11)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
 
-    expected_values = [[i, i+1] for i in range(0, 10, 2)]
+    # The last global batch only contains data for one replica.
+    expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
 
     self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                         expected_values)
@@ -140,7 +164,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
   def testBatchSplitting(self, input_type, split_batch_by):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
     batch_size = 10
-    dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
 
     updated_batch_size = (
         batch_size // split_batch_by if split_batch_by else batch_size)
@@ -182,7 +206,7 @@ class InputIteratorMultiWorkerTest(
   def testOneDevicePerWorker(self, input_type):
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda: dataset_ops.Dataset.range(4)
+      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           [[0, 0], [1, 1], [2, 2], [3, 3]], sess)
 
@@ -193,7 +217,7 @@ class InputIteratorMultiWorkerTest(
   def testTwoDevicesPerWorker(self, input_type):
     worker_devices = self._cpu_and_one_gpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda: dataset_ops.Dataset.range(4)
+      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           [[0, 1, 0, 1], [2, 3, 2, 3]], sess)
 
@@ -203,7 +227,9 @@ class InputIteratorMultiWorkerTest(
   def testTupleDataset(self, input_type):
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      def dataset_fn():
+
+      def dataset_fn(ctx):
+        del ctx
         dataset1 = dataset_ops.Dataset.range(4)
         dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
         return dataset_ops.Dataset.zip((dataset1, dataset2))
@@ -212,6 +238,36 @@ class InputIteratorMultiWorkerTest(
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           expected_values, sess)
 
+  @combinations.generate(
+      combinations.combine(
+          mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1))
+  def testUnevenDatasetBatches(self, input_type):
+    worker_devices = self._cpu_and_one_gpu_devices()
+    with context.graph_mode(), self.cached_session() as sess:
+      dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
+      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
+                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]]
+      self._test_iterator(input_type, dataset_fn, worker_devices,
+                          expected_values, sess)
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["graph"], input_type=["input_fn"], required_gpus=1))
+  def testDifferentDatasets(self, input_type):
+    worker_devices = self._cpu_and_one_gpu_devices()
+    with context.graph_mode(), self.cached_session() as sess:
+
+      def dataset_fn(ctx):
+        if ctx.input_pipeline_id == 0:
+          return dataset_ops.Dataset.range(8).batch(2)
+        else:
+          return dataset_ops.Dataset.range(9).batch(2)
+
+      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
+                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]]
+      self._test_iterator(input_type, dataset_fn, worker_devices,
+                          expected_values, sess)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 2204e3df454..51a394859e6 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -21,14 +21,21 @@ from __future__ import print_function
 
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.data.util import structure
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import input_ops
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util import nest
 
 
 class InputWorkers(object):
@@ -126,6 +133,7 @@ class InputIteratorImpl(InputIterator):
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
     replicas = []
+    worker_has_values = []
     for i, worker in enumerate(self._input_workers.worker_devices):
       if name is not None:
         d = tf_device.DeviceSpec.from_string(worker)
@@ -133,8 +141,61 @@ class InputIteratorImpl(InputIterator):
       else:
         new_name = None
       with ops.device(worker):
+        worker_has_value, next_element = (
+            self._iterators[i].get_next_as_list(new_name))
+        worker_has_values.append(worker_has_value)
         # Make `replicas` a flat list of values across all replicas.
-        replicas.extend(self._iterators[i].get_next_as_list(new_name))
+        replicas.append(next_element)
+
+    out_of_range_replicas = []
+
+    def out_of_range_fn(worker_index, device):
+      """This function will throw an OutOfRange error."""
+      # This is only called when there is no data left, so calling
+      # get_next() will trigger an OutOfRange error.
+      data = self._iterators[worker_index].get_next(device)
+      out_of_range_replicas.append(data)
+      return data
+
+    # `global_has_value` indicates whether there is data in this global batch.
+    # We do an all-reduce across all the workers in the multi-worker case.
+    # TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
+    if len(worker_has_values) > 1:
+      with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
+        # Place the tf.reduce_any op on device 0 to minimize communication
+        # cost.
+        # TODO(b/128545270): Investigate why placing it on worker 0 causes the
+        # entire data to be copied back from device to host.
+        global_has_value = math_ops.reduce_any(worker_has_values)
+    else:
+      global_has_value = worker_has_values[0]
+
+    results = []
+    for i, worker in enumerate(self._input_workers.worker_devices):
+      with ops.device(worker):
+        devices = self._input_workers.compute_devices_for_worker(i)
+        for j, device in enumerate(devices):
+          with ops.device(device):
+            # pylint: disable=undefined-loop-variable
+            # pylint: disable=cell-var-from-loop
+            # It is fine for the lambdas to capture variables from the loop as
+            # they are executed in the loop as well.
+            result = control_flow_ops.cond(global_has_value,
+                                           lambda: replicas[i][j],
+                                           lambda: out_of_range_fn(i, device))
+            # pylint: enable=cell-var-from-loop
+            # pylint: enable=undefined-loop-variable
+            results.append(result)
+    replicas = results
+
+    # Some dimensions in `replicas` will become unknown after we conditionally
+    # return the real tensors or the dummy tensors. We fix the input shapes by
+    # using the shapes from `out_of_range_replicas`, which come from calling
+    # get_next() on the underlying iterator.
+    flattened_replicas = nest.flatten(replicas)
+    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
+      flattened_replicas[i].set_shape(replica_data.get_shape())
+    replicas = nest.pack_sequence_as(replicas, flattened_replicas)
 
     return values.regroup(self._input_workers.device_map, replicas)
 
@@ -265,6 +326,45 @@ class DatasetIterator(InputIteratorImpl):
     super(DatasetIterator, self).__init__(input_workers, iterators)
 
 
+def _dummy_tensor_fn(value_structure):
+  """A function to create dummy tensors from `value_structure`."""
+
+  def create_dummy_tensor(feature_shape, feature_type):
+    """Create a dummy tensor with possible batch dimensions set to 0."""
+
+    # Ideally we would set the batch dimension to 0; however, in
+    # DistributionStrategy we don't know which dimension is the batch
+    # dimension, so we make a best-effort guess. If the feature has unknown
+    # dimensions, we set them to 0. If the feature shape is fully defined, we
+    # treat the first dimension as the batch dimension and set it to 0.
+    dims = []
+    for dim in feature_shape.dims:
+      if dim.value is None:
+        dims.append(tensor_shape.Dimension(0))
+      else:
+        dims.append(dim)
+    if feature_shape.is_fully_defined() and dims:
+      dims[0] = tensor_shape.Dimension(0)
+
+    # Create the dummy tensor.
+    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
+    return dummy_tensor
+
+  result = []
+  # pylint: disable=protected-access
+  for feature_shape, feature_type in zip(value_structure._flat_shapes,
+                                         value_structure._flat_types):
+    result.append(create_dummy_tensor(feature_shape, feature_type))
+
+  if isinstance(value_structure, structure.NestedStructure):
+    result = nest.pack_sequence_as(value_structure._nested_structure, result)
+  else:
+    result = result[0]
+  # pylint: enable=protected-access
+
+  return result
+
+
 class _SingleWorkerDatasetIterator(object):
   """Iterator for a single `tf.data.Dataset`."""
 
@@ -290,12 +390,51 @@ class _SingleWorkerDatasetIterator(object):
       self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)
 
-  def get_next_as_list(self, name=None):
-    """Get next element from the underlying iterator."""
+  def get_next(self, device, name=None):
+    """Get next element for the given device."""
     del name
     with ops.device(self._worker):
-      data_list = self._iterator.get_next()
-      return data_list
+      return self._iterator.get_next(device)
+
+  def get_next_as_list(self, name=None):
+    """Get the next element from the underlying iterator.
+
+    If there is no data left, a list of dummy tensors with possible batch
+    dimensions set to 0 will be returned.
+
+    Args:
+      name: not used.
+
+    Returns:
+      A boolean tensor that indicates whether there is any data in the next
+      element, and either the real data as the next element or a list of
+      dummy tensors if no data is left.
+    """
+    del name
+    with ops.device(self._worker):
+      data_list = self._iterator.get_next_as_optional()
+      result = []
+      for i, data in enumerate(data_list):
+        # Place the condition op on the same device as the data so the data
+        # doesn't need to be sent back to the worker.
+        with ops.device(self._devices[i]):
+          # MultiDeviceIterator fetches data in order, so we only need to
+          # check whether the first replica has a value to see whether there
+          # is data left for this single worker.
+          if i == 0:
+            worker_has_value = data.has_value()
+
+          # pylint: disable=unnecessary-lambda
+          # pylint: disable=cell-var-from-loop
+          real_data = control_flow_ops.cond(
+              data.has_value(),
+              lambda: data.get_value(),
+              lambda: _dummy_tensor_fn(data.value_structure))
+          result.append(real_data)
+          # pylint: enable=cell-var-from-loop
+          # pylint: enable=unnecessary-lambda
+
+      return worker_has_value, result
 
   def initialize(self):
     """Initialize underlying iterator.
@@ -334,12 +473,18 @@ class _SingleWorkerCallableIterator(object):
     self._worker = worker
     self._devices = devices
 
+  def get_next(self, device, name=None):
+    """Get next element for the given device from the callable."""
+    del device, name
+    with ops.device(self._worker):
+      return self._fn()
+
   def get_next_as_list(self, name=None):
     """Get next element from the callable."""
     del name
     with ops.device(self._worker):
       data_list = [self._fn() for _ in self._devices]
-      return data_list
+      return constant_op.constant(True), data_list
 
   def initialize(self):
     # TODO(petebu) Should this throw an exception instead?
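
Not part of the patch above: a minimal standalone sketch of the idea the change implements, assuming a recent TF 2.x eager environment. The helper name `_zero_batch_like` and the two example datasets are invented for illustration, and only flat (non-nested) elements are handled. When one replica runs out of data before the others, it is fed a dummy tensor whose batch dimension is 0, and iteration only stops once no replica has data left:

```python
import tensorflow as tf


def _zero_batch_like(spec):
  """Builds a zero-sized dummy tensor matching a flat tf.TensorSpec."""
  # Unknown dimensions become 0; for a fully defined shape, guess that the
  # first dimension is the batch dimension and zero it out.
  dims = [0 if d is None else d for d in spec.shape.as_list()]
  if spec.shape.is_fully_defined() and dims:
    dims[0] = 0
  return tf.zeros(dims, spec.dtype)


# Two "replicas" whose datasets run out at different points, mirroring
# testUnevenDatasetBatches / testDifferentDatasets in the test file above.
datasets = [
    tf.data.Dataset.range(9).batch(2),  # 5 batches; the last one is partial.
    tf.data.Dataset.range(8).batch(2),  # 4 batches.
]
iterators = [iter(ds) for ds in datasets]
specs = [ds.element_spec for ds in datasets]

while True:
  optionals = [it.get_next_as_optional() for it in iterators]
  has_value = [opt.has_value() for opt in optionals]
  # The global batch is exhausted only once *every* replica is out of data.
  if not tf.reduce_any(tf.stack(has_value)):
    break
  # A replica with no data left gets a dummy tensor with batch dimension 0.
  per_replica = [
      opt.get_value() if hv else _zero_batch_like(spec)
      for opt, hv, spec in zip(optionals, has_value, specs)
  ]
  print([t.numpy().tolist() for t in per_replica])

# The final iteration prints [[8], []] instead of raising OutOfRangeError.
```

The patch itself applies the same pattern per device via `MultiDeviceIterator.get_next_as_optional()`, wraps the real-vs-dummy choice in `control_flow_ops.cond` so it also works in graph mode, re-applies the static shapes afterwards, and reduces the per-worker `has_value` signals with `reduce_any` so every worker agrees on when the global batch ends; the sketch only shows the single-process flavor of that logic.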