diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py
index 80a1c7bae8f..204f52b034f 100644
--- a/tensorflow/contrib/distribute/python/input_lib_test.py
+++ b/tensorflow/contrib/distribute/python/input_lib_test.py
@@ -35,54 +35,36 @@ from tensorflow.python.util import nest
 class InputIteratorTestBase(test.TestCase):
 
-  def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
-                       devices, split_batch_by):
+  def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
+                     expected_values, sess=None, split_batch_by=None):
+    devices = nest.flatten([ds for _, ds in worker_device_pairs])
     device_map = values.ReplicaDeviceMap(devices)
     input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
 
     if input_type == "input_fn":
-      input_contexts = []
-      for i in range(input_workers.num_workers):
-        input_contexts.append(
-            distribute_lib.InputContext(
-                num_input_pipelines=input_workers.num_workers,
-                input_pipeline_id=i,
-                num_replicas_in_sync=len(devices)))
-
-      iterator = input_lib.InputFunctionIterator(dataset_fn, input_workers,
-                                                 input_contexts)
+      input_contexts = [
+          distribute_lib.InputContext() for _ in worker_device_pairs]
+      input_fn = lambda _: dataset_fn()
+      iterator = input_lib.InputFunctionIterator(
+          input_fn, input_workers, input_contexts)
     else:
       iterator = input_lib.DatasetIterator(
-          dataset_fn(distribute_lib.InputContext()), input_workers,
-          split_batch_by)
-    return iterator
-
-  def _test_iterator(self,
-                     input_type,
-                     dataset_fn,
-                     worker_device_pairs,
-                     expected_values,
-                     sess=None,
-                     split_batch_by=None):
-    devices = nest.flatten([ds for _, ds in worker_device_pairs])
-    iterator = self._create_iterator(
-        input_type, dataset_fn, worker_device_pairs, devices, split_batch_by)
+          dataset_fn(), input_workers, split_batch_by)
 
     evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
+
     evaluate(control_flow_ops.group(iterator.initialize()))
 
     for expected_value in expected_values:
       next_element = iterator.get_next()
       computed_value = evaluate(
           [values.select_replica(r, next_element) for r in range(len(devices))])
-      self.assertEqual(len(expected_value), len(computed_value))
-      for i in range(len(expected_value)):
-        self.assertAllEqual(expected_value[i], computed_value[i])
+      self.assertAllEqual(expected_value, computed_value)
 
     with self.assertRaises(errors.OutOfRangeError):
       next_element = iterator.get_next()
-      evaluate(
-          [values.select_replica(r, next_element) for r in range(len(devices))])
+      evaluate([values.select_replica(r, next_element)
+                for r in range(len(devices))])
 
     # After re-initializing the iterator, should be able to iterate again.
     evaluate(control_flow_ops.group(iterator.initialize()))
@@ -91,9 +73,7 @@ class InputIteratorTestBase(test.TestCase):
       next_element = iterator.get_next()
       computed_value = evaluate(
           [values.select_replica(r, next_element) for r in range(len(devices))])
-      self.assertEqual(len(expected_value), len(computed_value))
-      for i in range(len(expected_value)):
-        self.assertAllEqual(expected_value[i], computed_value[i])
+      self.assertAllEqual(expected_value, computed_value)
 
 
 class InputIteratorSingleWorkerTest(InputIteratorTestBase,
@@ -104,7 +84,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                                     input_type=["input_fn", "dataset"]))
   def testOneDeviceCPU(self, input_type):
     worker_device_pairs = [("", ["/device:CPU:0"])]
-    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda: dataset_ops.Dataset.range(10)
 
     expected_values = [[i] for i in range(10)]
 
@@ -117,7 +97,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                                     required_gpus=1))
   def testTwoDevicesOneGPUOneCPU(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda: dataset_ops.Dataset.range(10)
 
     expected_values = [[i, i+1] for i in range(0, 10, 2)]
 
@@ -130,9 +110,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                                     required_gpus=1))
   def testTupleDataset(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-
-    def dataset_fn(ctx):
-      del ctx
+    def dataset_fn():
       dataset1 = dataset_ops.Dataset.range(10)
       dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
       return dataset_ops.Dataset.zip((dataset1, dataset2))
@@ -142,17 +120,15 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
     self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                         expected_values)
 
-  @combinations.generate(
-      combinations.combine(
-          mode=["graph", "eager"],
-          input_type=["input_fn", "dataset"],
-          required_gpus=1))
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      input_type=["input_fn", "dataset"],
+      required_gpus=1))
   def testUnevenDatasetBatches(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
+    dataset_fn = lambda: dataset_ops.Dataset.range(11)
 
-    # The last global batch only contains data for one replica.
-    expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
+    expected_values = [[i, i+1] for i in range(0, 10, 2)]
 
     self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                         expected_values)
@@ -164,7 +140,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
   def testBatchSplitting(self, input_type, split_batch_by):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
     batch_size = 10
-    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
+    dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)
 
     updated_batch_size = (
         batch_size // split_batch_by if split_batch_by else batch_size)
@@ -206,7 +182,7 @@ class InputIteratorMultiWorkerTest(
   def testOneDevicePerWorker(self, input_type):
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
+      dataset_fn = lambda: dataset_ops.Dataset.range(4)
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           [[0, 0], [1, 1], [2, 2], [3, 3]], sess)
 
@@ -217,7 +193,7 @@ class InputIteratorMultiWorkerTest(
   def testTwoDevicesPerWorker(self, input_type):
     worker_devices = self._cpu_and_one_gpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
+      dataset_fn = lambda: dataset_ops.Dataset.range(4)
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           [[0, 1, 0, 1], [2, 3, 2, 3]], sess)
 
@@ -227,9 +203,7 @@ class InputIteratorMultiWorkerTest(
   def testTupleDataset(self, input_type):
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-
-      def dataset_fn(ctx):
-        del ctx
+      def dataset_fn():
         dataset1 = dataset_ops.Dataset.range(4)
         dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
         return dataset_ops.Dataset.zip((dataset1, dataset2))
@@ -238,36 +212,6 @@ class InputIteratorMultiWorkerTest(
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           expected_values, sess)
 
-  @combinations.generate(
-      combinations.combine(
-          mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1))
-  def testUnevenDatasetBatches(self, input_type):
-    worker_devices = self._cpu_and_one_gpu_devices()
-    with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
-      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
-                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]]
-      self._test_iterator(input_type, dataset_fn, worker_devices,
-                          expected_values, sess)
-
-  @combinations.generate(
-      combinations.combine(
-          mode=["graph"], input_type=["input_fn"], required_gpus=1))
-  def testDifferentDatasets(self, input_type):
-    worker_devices = self._cpu_and_one_gpu_devices()
-    with context.graph_mode(), self.cached_session() as sess:
-
-      def dataset_fn(ctx):
-        if ctx.input_pipeline_id == 0:
-          return dataset_ops.Dataset.range(8).batch(2)
-        else:
-          return dataset_ops.Dataset.range(9).batch(2)
-
-      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
-                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]]
-      self._test_iterator(input_type, dataset_fn, worker_devices,
-                          expected_values, sess)
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 7aa861dee97..35b1a8bf842 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -21,21 +21,14 @@ from __future__ import print_function
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
-from tensorflow.python.data.util import structure
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import input_ops
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.util import nest
 
 
 class InputWorkers(object):
@@ -133,7 +126,6 @@ class InputIteratorImpl(InputIterator):
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
     replicas = []
-    worker_has_values = []
     for i, worker in enumerate(self._input_workers.worker_devices):
       if name is not None:
         d = tf_device.DeviceSpec.from_string(worker)
@@ -141,61 +133,8 @@ class InputIteratorImpl(InputIterator):
       else:
         new_name = None
       with ops.device(worker):
-        worker_has_value, next_element = (
-            self._iterators[i].get_next_as_list(new_name))
-        worker_has_values.append(worker_has_value)
         # Make `replicas` a flat list of values across all replicas.
-        replicas.append(next_element)
-
-    out_of_range_replicas = []
-
-    def out_of_range_fn(worker_index, device):
-      """This function will throw an OutOfRange error."""
-      # As this will be only called when there is no data left, so calling
-      # get_next() will trigger an OutOfRange error.
-      data = self._iterators[worker_index].get_next(device)
-      out_of_range_replicas.append(data)
-      return data
-
-    # `global_has_value` indicates whether there is data in this global batch.
-    # We do a all-reduce across all the workers in the multi-worker case.
-    # TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
-    if len(worker_has_values) > 1:
-      with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
-        # Place the tf.reduce_any op in device 0 to minimize communication
-        # cost.
-        # TODO(b/128545270): Investigate why placing it on worker 0 will cause
-        # the entire data to copy back from device to host.
-        global_has_value = math_ops.reduce_any(worker_has_values)
-    else:
-      global_has_value = worker_has_values[0]
-
-    results = []
-    for i, worker in enumerate(self._input_workers.worker_devices):
-      with ops.device(worker):
-        devices = self._input_workers.compute_devices_for_worker(i)
-        for j, device in enumerate(devices):
-          with ops.device(device):
-            # pylint: disable=undefined-loop-variable
-            # pylint: disable=cell-var-from-loop
-            # It is fine for the lambda to capture variables from the loop as
-            # the lambda is executed in the loop as well.
-            result = control_flow_ops.cond(global_has_value,
-                                           lambda: replicas[i][j],
-                                           lambda: out_of_range_fn(i, device))
-            # pylint: enable=cell-var-from-loop
-            # pylint: enable=undefined-loop-variable
-            results.append(result)
-    replicas = results
-
-    # Some dimensions in `replicas` will become unknown after we conditionally
-    # return the real tensors or the dummy tensors. We fix the input shapes by
-    # using the shapes from `out_of_range_replicas` because it is calling
-    # get_next() inside.
-    flattened_replicas = nest.flatten(replicas)
-    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
-      flattened_replicas[i].set_shape(replica_data.get_shape())
-    replicas = nest.pack_sequence_as(replicas, flattened_replicas)
+        replicas.extend(self._iterators[i].get_next_as_list(new_name))
 
     return values.regroup(self._input_workers.device_map, replicas)
 
@@ -326,45 +265,6 @@ class DatasetIterator(InputIteratorImpl):
     super(DatasetIterator, self).__init__(input_workers, iterators)
 
 
-def _dummy_tensor_fn(value_structure):
-  """A function to create dummy tensors from `value_structure`."""
-
-  def create_dummy_tensor(feature_shape, feature_type):
-    """Create a dummy tensor with possible batch dimensions set to 0."""
-
-    # Ideally we should set the batch dimension to 0, however as in
-    # DistributionStrategy we don't know the batch dimension, we try to
-    # guess it as much as possible. If the feature has unknown dimensions, we
-    # will set them to 0. If the feature shape is already static, we guess the
-    # first dimension as batch dimension and set it to 0.
-    dims = []
-    for dim in feature_shape.dims:
-      if dim.value is None:
-        dims.append(tensor_shape.Dimension(0))
-      else:
-        dims.append(dim)
-    if feature_shape.is_fully_defined() and dims:
-      dims[0] = tensor_shape.Dimension(0)
-
-    # Create the dummy tensor.
-    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
-    return dummy_tensor
-
-  result = []
-  # pylint: disable=protected-access
-  for feature_shape, feature_type in zip(value_structure._flat_shapes,
-                                         value_structure._flat_types):
-    result.append(create_dummy_tensor(feature_shape, feature_type))
-
-  if isinstance(value_structure, structure.NestedStructure):
-    result = nest.pack_sequence_as(value_structure._nested_structure, result)
-  else:
-    result = result[0]
-  # pylint: enable=protected-access
-
-  return result
-
-
 class _SingleWorkerDatasetIterator(object):
   """Iterator for a single `tf.data.Dataset`."""
 
@@ -390,51 +290,12 @@ class _SingleWorkerDatasetIterator(object):
     self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
         self._dataset, self._devices)
 
-  def get_next(self, device, name=None):
-    """Get next element for the given device."""
-    del name
-    with ops.device(self._worker):
-      return self._iterator.get_next(device)
-
   def get_next_as_list(self, name=None):
-    """Get next element from underlying iterator.
-
-    If there is no data left, a list of dummy tensors with possible batch
-    dimensions set to 0 will be returned.
-
-    Args:
-      name: not used.
-
-    Returns:
-      A boolean tensor indicates whether there is any data in next element and
-      the real data as the next element or a list of dummy tensors if no data
-      left.
-    """
+    """Get next element from the underlying iterator."""
     del name
     with ops.device(self._worker):
-      data_list = self._iterator.get_next_as_optional()
-      result = []
-      for i, data in enumerate(data_list):
-        # Place the condition op in the same device as the data so the data
-        # doesn't need to be sent back to the worker.
-        with ops.device(self._devices[i]):
-          # As MultiDeviceIterator will fetch data in order, so we only need to
-          # check if the first replica has value to see whether there is data
-          # left for this single worker.
-          if i == 0:
-            worker_has_value = data.has_value()
-
-          # pylint: disable=unnecessary-lambda
-          # pylint: disable=cell-var-from-loop
-          real_data = control_flow_ops.cond(
-              data.has_value(),
-              lambda: data.get_value(),
-              lambda: _dummy_tensor_fn(data.value_structure))
-          result.append(real_data)
-          # pylint: enable=cell-var-from-loop
-          # pylint: enable=unnecessary-lambda
-
-      return worker_has_value, result
+      data_list = self._iterator.get_next()
+      return data_list
 
   def initialize(self):
     """Initialze underlying iterator.
@@ -473,18 +334,12 @@ class _SingleWorkerCallableIterator(object):
     self._worker = worker
     self._devices = devices
 
-  def get_next(self, device, name=None):
-    """Get next element for the given device from the callable."""
-    del device, name
-    with ops.device(self._worker):
-      return self._fn()
-
   def get_next_as_list(self, name=None):
    """Get next element from the callable."""
     del name
     with ops.device(self._worker):
       data_list = [self._fn() for _ in self._devices]
-      return constant_op.constant(True), data_list
+      return data_list
 
   def initialize(self):
     # TODO(petebu) Should this throw an exception instead?
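For readers tracing the behavioral change: with the dummy-tensor padding removed, an exhausted per-replica iterator now surfaces end-of-data directly instead of filling replicas that have no data left with zero-sized tensors, which is why testUnevenDatasetBatches now expects only full pairs from Dataset.range(11) before OutOfRangeError. The sketch below is a minimal, framework-free illustration of that contract, not TensorFlow API; the names OutOfRange and StrictMultiReplicaIterator are hypothetical.

class OutOfRange(Exception):
  """Stands in for tf.errors.OutOfRangeError in this sketch."""


class StrictMultiReplicaIterator(object):
  """Hands one element per replica; raises once any replica runs dry."""

  def __init__(self, data, num_replicas):
    self._it = iter(data)
    self._num_replicas = num_replicas

  def get_next(self):
    batch = []
    for _ in range(self._num_replicas):
      try:
        batch.append(next(self._it))
      except StopIteration:
        # No dummy padding: signal end of data as soon as a replica
        # cannot be fed, mirroring the simplified get_next_as_list().
        raise OutOfRange("end of dataset")
    return batch


# Mirrors the single-worker testUnevenDatasetBatches above: range(11) across
# two replicas yields five full pairs; the final lone element (10) cannot fill
# both replicas, so iteration ends with an out-of-range error instead of a
# partially padded batch.
it = StrictMultiReplicaIterator(range(11), num_replicas=2)
collected = []
try:
  while True:
    collected.append(it.get_next())
except OutOfRange:
  pass
assert collected == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

The trade-off this models: the removed code bought tolerance of uneven final batches at the cost of a per-step conditional and cross-worker reduce on every get_next(); the simplified path does no such bookkeeping and simply lets the out-of-range signal propagate.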