From d2dd369f9dadb5dd3220ababa299ad89cd8e8574 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 29 Dec 2018 08:46:45 -0800 Subject: [PATCH] Split input-related classes out of python/distribute/values.py into new file .../input_lib.py. PiperOrigin-RevId: 227227637 --- tensorflow/contrib/distribute/python/BUILD | 31 +- .../python/collective_all_reduce_strategy.py | 11 +- .../distribute/python/input_lib_test.py | 480 ++++++++++++ .../distribute/python/mirrored_strategy.py | 8 +- .../distribute/python/one_device_strategy.py | 12 +- .../python/parameter_server_strategy.py | 13 +- .../contrib/distribute/python/tpu_strategy.py | 12 +- .../contrib/distribute/python/values_test.py | 446 ----------- tensorflow/python/distribute/BUILD | 20 +- tensorflow/python/distribute/input_lib.py | 707 ++++++++++++++++++ .../python/distribute/mirrored_strategy.py | 16 +- tensorflow/python/distribute/values.py | 678 ----------------- 12 files changed, 1272 insertions(+), 1162 deletions(-) create mode 100644 tensorflow/contrib/distribute/python/input_lib_test.py create mode 100644 tensorflow/python/distribute/input_lib.py diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 2d6a08df9a2..f27224e46e4 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -23,17 +23,14 @@ cuda_py_test( additional_deps = [ ":combinations", ":mirrored_strategy", - ":multi_worker_test_base", "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:device_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", @@ -45,14 
+42,36 @@ cuda_py_test( ], ) +cuda_py_test( + name = "input_lib_test", + srcs = ["input_lib_test.py"], + additional_deps = [ + ":combinations", + ":mirrored_strategy", + ":multi_worker_test_base", + "@absl_py//absl/testing:parameterized", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:values", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "no_pip", + ], +) + py_library( name = "mirrored_strategy", srcs = ["mirrored_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:mirrored_strategy", - "//tensorflow/python/distribute:values", ], ) @@ -69,6 +88,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:multi_worker_util", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", @@ -119,6 +139,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", @@ -139,6 +160,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:cross_device_utils", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:multi_worker_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", @@ -289,6 +311,7 @@ py_library( "//tensorflow/python:framework_ops", 
"//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", ], diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 12197c3d0de..f6361cb6e89 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -26,6 +26,7 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_li from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import values from tensorflow.python.eager import context @@ -130,7 +131,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) - self._input_workers = values.InputWorkers( + self._input_workers = input_lib.InputWorkers( self._device_map, [(self._worker_device, self.worker_devices)]) self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, @@ -229,13 +230,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): """Distributes the dataset to each local GPU.""" # TODO(yuefengz): shard the dataset. 
worker_index = 0 - return values.PerReplicaDataset( + return input_lib.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._input_workers, worker_index, prefetch_on_device=True) def _make_dataset_iterator(self, dataset): - return values.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -252,7 +253,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): input_pipeline_id=input_pipeline_id, num_replicas_in_sync=self._num_replicas_in_sync) - return values.InputFunctionIterator( + return input_lib.InputFunctionIterator( input_fn, self._input_workers, [input_context]) def _configure(self, diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py new file mode 100644 index 00000000000..f589cd6ad54 --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_lib_test.py @@ -0,0 +1,480 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the input_lib library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.util import nest + + +class PerReplicaDatasetTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _test_iterator(self, devices, dataset, expected_values): + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map) + per_replica_dataset = input_lib.PerReplicaDataset(dataset, input_workers, 0) + if context.executing_eagerly(): + iterator = per_replica_dataset.make_one_shot_iterator() + else: + iterator = per_replica_dataset.make_initializable_iterator() + self.evaluate([iterator.initializer]) + + for expected_value in expected_values: + next_element = iterator.get_next_as_list() + computed_value = self.evaluate(next_element) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next_as_list() + self.evaluate(next_element) + + @test_util.run_in_graph_and_eager_modes + def 
testOneDevice(self): + devices = ["/device:CPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleDevices(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTupleDataset(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnevenDatasetBatches(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(devices, dataset, expected_values) + + def testInitializableIterator(self): + with context.graph_mode(): + devices = ["/device:CPU:0"] + # Using random input since that is only allowed with initializable + # iterator. 
+ dataset = dataset_ops.Dataset.from_tensor_slices( + random_ops.random_uniform((10,))) + + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map) + per_replica_dataset = input_lib.PerReplicaDataset( + dataset, input_workers, 0) + iterator = per_replica_dataset.make_initializable_iterator() + + self.evaluate(iterator.initializer) + next_element = iterator.get_next_as_list() + for _ in range(10): + self.evaluate(next_element) + + # Should fail after the input is finished. + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(next_element) + + # After re-initializing the iterator, should be able to iterate again. + self.evaluate(iterator.initializer) + for _ in range(10): + self.evaluate(next_element) + + +class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): + + def _test_iterator(self, sess, iterator, devices, expected_values): + next_element = iterator.get_next() + for r, device in enumerate(devices): + v = values.select_replica(r, next_element) + # The `v` here can be a tuple. 
+ for element in nest.flatten(v): + self.assertTrue(element.device in device) + + for expected_value in expected_values: + t = [values.select_replica(r, next_element) for r in range(len(devices))] + actual = sess.run(t) + self.assertEqual(expected_value, actual) + + with self.assertRaises(errors.OutOfRangeError): + sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) + + def _test_dataset(self, dataset_fn, worker_devices, devices, + expected_values): + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_devices) + multi_worker_dataset = input_lib.MultiWorkerDataset( + dataset_fn, input_workers) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + with self.cached_session() as sess: + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, expected_values) + + def _cpu_devices(self): + worker_devices = ( + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"]) + ) + devices = [ + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_devices, devices + + def _cpu_and_one_gpu_devices(self): + worker_devices = ( + ("/job:worker/replica:0/task:0", ( + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + )), + ("/job:worker/replica:0/task:1", ( + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + )) + ) + devices = [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_devices, devices + + def testDataDistributionOneDevicePerWorker(self): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(): + 
dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset( + dataset_fn, worker_devices, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) + + def testDataDistributionTwoDevicePerWorker(self): + if context.num_gpus() < 1: + self.skipTest("A GPU is not available for this test.") + worker_devices, devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset( + dataset_fn, worker_devices, devices, + [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]]) + + def testTupleDataset(self): + worker_devices, devices = self._cpu_devices() + + with context.graph_mode(): + + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(8) + dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(8)] + self._test_dataset(dataset_fn, worker_devices, devices, + expected_values) + + def testInitializableIterator(self): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(8) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_devices) + multi_worker_dataset = input_lib.MultiWorkerDataset( + dataset_fn, input_workers) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + + sess.run(multi_worker_iterator.initializer) + self._test_iterator( + sess, multi_worker_iterator, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) + + # After re-initializing the iterator, should be able to iterate again. + sess.run(multi_worker_iterator.initializer) + self._test_iterator( + sess, multi_worker_iterator, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) + + def testValueErrorForIterator(self): + # Incompatible arguments.
+ d1 = "/device:GPU:0" + d2 = "/device:GPU:1" + device_map = values.ReplicaDeviceMap([d1, d2]) + input_workers = input_lib.InputWorkers( + device_map, (("w1", (d1,)), ("w2", (d2,)))) + with self.assertRaises(ValueError): + input_lib.MultiWorkerDataIterator([("w1", None)], input_workers) + + def testDuplicateDevices(self): + _, devices = self._cpu_devices() + devices.append("/job:worker/replica:0/task:0/device:CPU:0") + with self.assertRaises(ValueError): + _ = values.ReplicaDeviceMap(devices) + + +class InputIteratorTestBase(test.TestCase): + + def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) + + if input_type == "input_fn": + input_contexts = [ + distribute_lib.InputContext() for _ in worker_device_pairs] + input_fn = lambda _: dataset_fn() + iterator = input_lib.InputFunctionIterator( + input_fn, input_workers, input_contexts) + else: + iterator = input_lib.DatasetIterator( + dataset_fn(), input_workers, split_batch_by) + + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_replica(r, next_element) + for r in range(len(devices))]) + + # After re-initializing the iterator, should be able to iterate again. 
+ evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + +class InputIteratorSingleWorkerTest(InputIteratorTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"])) + def testOneDeviceCPU(self, input_type): + worker_device_pairs = [("", ["/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesOneGPUOneCPU(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTupleDataset(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testUnevenDatasetBatches(self, 
input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["dataset"], + split_batch_by=[None, 2], + required_gpus=1)) + def testBatchSplitting(self, input_type, split_batch_by): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + batch_size = 10 + dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + + updated_batch_size = ( + batch_size // split_batch_by if split_batch_by else batch_size) + expected_values = [[range(i, i+updated_batch_size), + range(i+updated_batch_size, i+2*updated_batch_size)] + for i in range(0, 100, updated_batch_size*2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, + split_batch_by=split_batch_by) + + +class InputIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, + parameterized.TestCase): + + def _cpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] + + def _cpu_and_one_gpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), + ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ]) + ] + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testOneDevicePerWorker(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: 
dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesPerWorker(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 1, 0, 1], [2, 3, 2, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testTupleDataset(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + +class SplitDatasetBatchTest(test.TestCase): + + def testBatchDataset(self): + dataset = dataset_ops.Dataset.range(100).batch(20) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def testMapAndBatchDataset(self): + dataset = dataset_ops.Dataset.range(100) + dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def testPrefetchDataset(self): + 
dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 71e50b83b07..db8fd983078 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -21,8 +21,8 @@ from __future__ import print_function import functools from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import mirrored_strategy -from tensorflow.python.distribute import values # pylint: disable=protected-access,invalid-name @@ -135,14 +135,14 @@ class MirroredExtended(CoreMirroredExtended): Returns: An `InputIterator` which returns inputs for each step of the computation. 
""" - return values.DatasetIterator(dataset, self._input_workers) + return input_lib.DatasetIterator(dataset, self._input_workers) def _distribute_dataset(self, dataset_fn): if self._local_mode: - return values.PerReplicaDataset( + return input_lib.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._input_workers, 0) else: - return values.MultiWorkerDataset( + return input_lib.MultiWorkerDataset( functools.partial(self._call_dataset_fn, dataset_fn), self._input_workers, auto_shard=self._auto_shard_dataset) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 700751d68c5..fb470f8546f 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import values from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -52,7 +53,8 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): worker = device_util.canonicalize("/device:CPU:0") worker_device_pairs = [(worker, [self._device])] device_map = values.SingleDeviceMap(device) - self._input_workers = values.InputWorkers(device_map, worker_device_pairs) + self._input_workers = input_lib.InputWorkers( + device_map, worker_device_pairs) def _create_variable(self, next_creator, *args, **kwargs): colocate_with = kwargs.pop("colocate_with", None) @@ -67,17 +69,17 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): def _make_dataset_iterator(self, dataset): """Make iterator from dataset without splitting the batch.""" - return values.DatasetIterator(dataset, self._input_workers) + return input_lib.DatasetIterator(dataset, 
self._input_workers) def _distribute_dataset(self, dataset_fn): - return values.PerReplicaDataset( + return input_lib.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._input_workers, 0) def _make_input_fn_iterator( self, input_fn, replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - return values.InputFunctionIterator( + return input_lib.InputFunctionIterator( input_fn, self._input_workers, [distribute_lib.InputContext()]) def _broadcast_to(self, tensor, destinations): @@ -91,7 +93,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) - ctx = values.MultiStepContext() + ctx = input_lib.MultiStepContext() def body(i, *args): """A wrapper around `fn` to create the while loop body.""" del args diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index a6e924b509f..461e1bca21e 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import values from tensorflow.python.eager import context @@ -153,7 +154,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): compute_devices = (worker_device,) self._device_map = values.ReplicaDeviceMap(compute_devices) - self._input_workers = values.InputWorkers( + self._input_workers = input_lib.InputWorkers( self._device_map, [(worker_device, compute_devices)]) # In distributed mode, 
place variables on ps jobs in a round-robin fashion. @@ -210,7 +211,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): compute_devices = (_LOCAL_CPU,) self._device_map = values.ReplicaDeviceMap(compute_devices) - self._input_workers = values.InputWorkers( + self._input_workers = input_lib.InputWorkers( self._device_map, [(worker_device, compute_devices)]) # If there is only one GPU, put everything on that GPU. Otherwise, place @@ -237,13 +238,13 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" - return values.PerReplicaDataset( + return input_lib.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._input_workers, 0, prefetch_on_device=True) def _make_dataset_iterator(self, dataset): - return values.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -262,7 +263,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): num_input_pipelines=num_input_pipelines, input_pipeline_id=input_pipeline_id, num_replicas_in_sync=self._num_replicas_in_sync) - return values.InputFunctionIterator( + return input_lib.InputFunctionIterator( input_fn, self._input_workers, [input_context]) def _broadcast_to(self, tensor, destinations): diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 89b48d3f13e..10b7ef04072 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -33,6 +33,7 @@ from tensorflow.python.client import session as session_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute 
import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver_lib @@ -204,7 +205,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): (self.get_host(hid), [self.get_host_cpu_device(hid)]) for hid in range(self.num_hosts) ] - self._input_workers = values.InputWorkers(input_device_map, worker_devices) + self._input_workers = input_lib.InputWorkers( + input_device_map, worker_devices) # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. @@ -304,11 +306,11 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _make_dataset_iterator(self, dataset): """Make iterators for each of the TPU hosts.""" - return values.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._num_replicas_in_sync) def _distribute_dataset(self, dataset_fn): - return values.MultiWorkerDataset( + return input_lib.MultiWorkerDataset( functools.partial(self._call_dataset_fn, dataset_fn), self._input_workers) @@ -339,7 +341,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) - ctx = values.MultiStepContext() + ctx = input_lib.MultiStepContext() def run_fn(*args, **kwargs): """Single step on the TPU device.""" diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 73efb524b93..51c58b0b2f3 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -22,28 +22,20 @@ import os from absl.testing import parameterized from tensorflow.contrib.distribute.python 
import combinations -from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.data.experimental.ops import batching -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import saver as saver_lib -from tensorflow.python.util import nest class DistributedValuesTest(test.TestCase): @@ -354,444 +346,6 @@ class RegroupAndSelectDeviceTest(test.TestCase): merged_estimator_spec)) -class PerReplicaDatasetTest(test.TestCase): - - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - - def _test_iterator(self, devices, dataset, expected_values): - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map) - per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0) - if context.executing_eagerly(): - iterator = per_replica_dataset.make_one_shot_iterator() - else: - iterator = per_replica_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - for expected_value in expected_values: - next_element = iterator.get_next_as_list() - computed_value = self.evaluate(next_element) - 
self.assertEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next_as_list() - self.evaluate(next_element) - - @test_util.run_in_graph_and_eager_modes - def testOneDevice(self): - devices = ["/device:CPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleDevices(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testTupleDataset(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnevenDatasetBatches(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(devices, dataset, expected_values) - - def testInitializableIterator(self): - with context.graph_mode(): - devices = ["/device:CPU:0"] - # Using random input since that is 
only allowed with initializable - # iterator. - dataset = dataset_ops.Dataset.from_tensor_slices( - random_ops.random_uniform((10,))) - - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map) - per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0) - iterator = per_replica_dataset.make_initializable_iterator() - - self.evaluate(iterator.initializer) - next_element = iterator.get_next_as_list() - for _ in range(10): - self.evaluate(next_element) - - # Should fail after the input is finished. - with self.assertRaises(errors.OutOfRangeError): - self.evaluate(next_element) - - # After re-initializing the iterator, should be able to iterate again. - self.evaluate(iterator.initializer) - for _ in range(10): - self.evaluate(next_element) - - -class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - - def _test_iterator(self, sess, iterator, devices, expected_values): - next_element = iterator.get_next() - for r, device in enumerate(devices): - v = values.select_replica(r, next_element) - # The `v` here can be a tuple. 
- for element in nest.flatten(v): - self.assertTrue(element.device in device) - - for expected_value in expected_values: - t = [values.select_replica(r, next_element) for r in range(len(devices))] - actual = sess.run(t) - self.assertEqual(expected_value, actual) - - with self.assertRaises(errors.OutOfRangeError): - sess.run([values.select_replica(r, next_element) - for r in range(len(devices))]) - - def _test_dataset(self, dataset_fn, worker_devices, devices, - expected_values): - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map, worker_devices) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, input_workers) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - with self.cached_session() as sess: - sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, expected_values) - - def _cpu_devices(self): - worker_devices = ( - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"]) - ) - devices = [ - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def _cpu_and_one_gpu_devices(self): - worker_devices = ( - ("/job:worker/replica:0/task:0", ( - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - )), - ("/job:worker/replica:0/task:1", ( - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - )) - ) - devices = [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def testDataDistributionOneDevicePerWorker(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - 
dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset( - dataset_fn, worker_devices, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - def testDataDistributionTwoDevicePerWorker(self): - if context.num_gpus() < 1: - self.skipTest("A GPU is not available for this test.") - worker_devices, devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset( - dataset_fn, worker_devices, devices, - [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]]) - - def testTupleDataset(self): - worker_devices, devices = self._cpu_devices() - - with context.graph_mode(): - - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(8) - dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(8)] - self._test_dataset(dataset_fn, worker_devices, devices, - expected_values) - - def testInitializableIterator(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(8) - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map, worker_devices) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, input_workers) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - - sess.run(multi_worker_iterator.initializer) - self._test_iterator( - sess, multi_worker_iterator, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - # After re-initializing the iterator, should be able to iterate again. - sess.run(multi_worker_iterator.initializer) - self._test_iterator( - sess, multi_worker_iterator, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - def testValueErrorForIterator(self): - # Incompatiable arguments. 
- d1 = "/device:GPU:0" - d2 = "/device:GPU:1" - device_map = values.ReplicaDeviceMap([d1, d2]) - input_workers = values.InputWorkers( - device_map, (("w1", (d1,)), ("w2", (d2,)))) - with self.assertRaises(ValueError): - values.MultiWorkerDataIterator([("w1", None)], input_workers) - - def testDuplicateDevices(self): - _, devices = self._cpu_devices() - devices.append("/job:worker/replica:0/task:0/device:CPU:0") - with self.assertRaises(ValueError): - _ = values.ReplicaDeviceMap(devices) - - -class InputIteratorTestBase(test.TestCase): - - def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, split_batch_by=None): - devices = nest.flatten([ds for _, ds in worker_device_pairs]) - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map, worker_device_pairs) - - if input_type == "input_fn": - input_contexts = [ - distribute_lib.InputContext() for _ in worker_device_pairs] - input_fn = lambda _: dataset_fn() - iterator = values.InputFunctionIterator( - input_fn, input_workers, input_contexts) - else: - iterator = values.DatasetIterator( - dataset_fn(), input_workers, split_batch_by) - - evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) - - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - evaluate([values.select_replica(r, next_element) - for r in range(len(devices))]) - - # After re-initializing the iterator, should be able to iterate again. 
- evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) - - -class InputIteratorSingleWorkerTest(InputIteratorTestBase, - parameterized.TestCase): - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"])) - def testOneDeviceCPU(self, input_type): - worker_device_pairs = [("", ["/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesOneGPUOneCPU(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTupleDataset(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testUnevenDatasetBatches(self, 
input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["dataset"], - split_batch_by=[None, 2], - required_gpus=1)) - def testBatchSplitting(self, input_type, split_batch_by): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - batch_size = 10 - dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) - - updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) - expected_values = [[range(i, i+updated_batch_size), - range(i+updated_batch_size, i+2*updated_batch_size)] - for i in range(0, 100, updated_batch_size*2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, - split_batch_by=split_batch_by) - - -class InputIteratorMultiWorkerTest( - multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, - parameterized.TestCase): - - def _cpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])] - - def _cpu_and_one_gpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - ]), - ("/job:worker/replica:0/task:1", [ - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ]) - ] - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testOneDevicePerWorker(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: 
dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesPerWorker(self, input_type): - worker_devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 1, 0, 1], [2, 3, 2, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testTupleDataset(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(4) - dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess) - - -class SplitDatasetBatchTest(test.TestCase): - - def testBatchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20) - split_batch_by = 2 - result_dataset = values._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testMapAndBatchDataset(self): - dataset = dataset_ops.Dataset.range(100) - dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) - split_batch_by = 2 - result_dataset = values._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testPrefetchDataset(self): - dataset = 
dataset_ops.Dataset.range(100).batch(20).prefetch(1) - split_batch_by = 2 - result_dataset = values._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - class MirroredVariableTest(test.TestCase, parameterized.TestCase): config = config_pb2.ConfigProto() diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 02957b2fefb..987fb00454c 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -219,6 +219,7 @@ py_library( ":cross_device_ops", ":device_util", ":distribute_lib", + ":input_lib", ":multi_worker_util", ":reduce_util", ":shared_variable_creator", @@ -253,6 +254,23 @@ py_library( ], ) +py_library( + name = "input_lib", + srcs = ["input_lib.py"], + deps = [ + ":device_util", + ":distribute_lib", + ":input_ops", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:multi_device_iterator_ops", + "//tensorflow/python/eager:context", + ], +) + py_library( name = "input_ops", srcs = ["input_ops.py"], @@ -348,14 +366,12 @@ py_library( deps = [ ":device_util", ":distribute_lib", - ":input_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python/data/ops:multi_device_iterator_ops", "//tensorflow/python/eager:context", "//tensorflow/python/training/checkpointable:base", "@six_archive//:six", diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py new file mode 100644 index 00000000000..cbe6518e5cb --- /dev/null +++ b/tensorflow/python/distribute/input_lib.py @@ -0,0 +1,707 @@ +# 
Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various classes representing distributed inputs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import multi_device_iterator_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import input_ops +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.util import nest + + +class InputWorkers(object): + """A 1-to-many mapping from input worker devices to compute devices.""" + + def __init__(self, device_map, worker_device_pairs=None, logical_device=0): + """Initialize an `InputWorkers` object. + + Args: + device_map: A `DeviceMap` with the computation devices fed by the + input workers. 
+ worker_device_pairs: A sequence of pairs: + `(input device, a tuple of compute devices fed by that input device)`. + logical_device: The logical device of `device_map` to feed. + """ + self._device_map = device_map + self._logical_device = logical_device + if worker_device_pairs is None: + worker_device_pairs = (( + device_util.canonicalize("/device:CPU:0"), + device_map.logical_to_actual_devices(logical_device)),) + self._input_worker_devices = tuple(d for d, _ in worker_device_pairs) + self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f) + for _, f in worker_device_pairs) + flattened = tuple(d for l in self._fed_devices for d in l) + assert (flattened == + device_map.logical_to_actual_devices(logical_device)), ( + "flattened: %s logical device %d: %s" % + (flattened, logical_device, + device_map.logical_to_actual_devices(logical_device))) + + @property + def device_map(self): + return self._device_map + + @property + def logical_device(self): + return self._logical_device + + @property + def num_workers(self): + return len(self._input_worker_devices) + + @property + def worker_devices(self): + return self._input_worker_devices + + def compute_devices_for_worker(self, worker_index): + return self._fed_devices[worker_index] + + def __repr__(self): + devices = self.worker_devices + debug_repr = ",\n".join(" %d %s: %s" % + (i, devices[i], self._fed_devices[i]) + for i in range(len(devices))) + return "%s:{\n%s\n device_map: %s}" % ( + self.__class__.__name__, debug_repr, self._device_map) + + +class PerReplicaDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `PerReplicaDataset`.""" + + def __init__(self, iterator, input_workers, worker_index, prefetch_on_device): + assert isinstance(input_workers, InputWorkers) + self._iterator = iterator + self._input_workers = input_workers + self._worker_index = worker_index + self._prefetch_on_device = prefetch_on_device + + @property + def initializer(self): + return 
self._iterator.initializer + + def get_next_as_list(self, name=None): + """Scatter the input across devices.""" + if self._prefetch_on_device: + data_list = self._iterator.get_next() + else: + batch = self._iterator.get_next(name=name) + data_list = [] + def get_ith(i): + return lambda x: x[i] + + devices = self._input_workers.compute_devices_for_worker( + self._worker_index) + for i, d in enumerate(devices): + v = nest.map_structure(get_ith(i), batch) + if context.executing_eagerly(): + with ops.device(d): + v = nest.map_structure(array_ops.identity, v) + data_list.append(v) + + return data_list + + def get_next(self, name=None): + assert self._input_workers.num_workers == 1 + data_list = self.get_next_as_list(name) + return values.regroup(self._input_workers.device_map, data_list) + + @property + def output_classes(self): + return self._iterator.output_classes + + @property + def output_shapes(self): + return self._iterator.output_shapes + + @property + def output_types(self): + return self._iterator.output_types + + +class PerReplicaDataset(object): + """Like `tf.data.Dataset` split devices, producing `PerReplica` data.""" + + def __init__(self, dataset, input_workers, worker_index, + prefetch_on_device=None): + assert isinstance(input_workers, InputWorkers) + assert worker_index is not None + assert worker_index is not True # pylint: disable=g-bool-id-comparison + assert worker_index is not False # pylint: disable=g-bool-id-comparison + self._input_workers = input_workers + self._worker_index = worker_index + + # Default to using prefetching, unless specified. + self._prefetch_on_device = prefetch_on_device + if self._prefetch_on_device is None: + self._prefetch_on_device = True + + self._dataset = dataset + if not self._prefetch_on_device: + # TODO(priyag): If dropping remainder is not appropriate, find another + # approach to distributing the dataset when not possible to divide evenly. + # Possibly not an issue when we start using PartitionedDataset. 
+ num_replicas = len( + self._input_workers.compute_devices_for_worker(self._worker_index)) + self._dataset = self._dataset.batch(num_replicas, drop_remainder=True) + else: + self._replica_devices = self._input_workers.compute_devices_for_worker( + self._worker_index) + + def make_one_shot_iterator(self): + """Get a one time use iterator for the distributed PerReplicaDataset.""" + # Graph mode with one shot iterator is disabled. + if not context.executing_eagerly(): + raise ValueError("Cannot create a one shot iterator. Please use " + "`make_initializable_iterator()` instead.") + if self._prefetch_on_device: + dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( + self._dataset, self._replica_devices) + else: + dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset) + return PerReplicaDataIterator( + dataset_iterator, + self._input_workers, + self._worker_index, + prefetch_on_device=self._prefetch_on_device) + + def make_initializable_iterator(self): + """Get an initializable iterator for the distributed PerReplicaDataset.""" + # Eager mode generates already initialized iterators. Hence we cannot create + # an initializable iterator. + if context.executing_eagerly(): + raise ValueError("Cannot create initializable iterator in Eager mode. " + "Please use `make_one_shot_iterator` instead.") + if self._prefetch_on_device: + dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( + self._dataset, self._replica_devices) + else: + dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset) + return PerReplicaDataIterator( + dataset_iterator, self._input_workers, self._worker_index, + prefetch_on_device=self._prefetch_on_device) + + +class MultiWorkerDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`.""" + + def __init__(self, iterators, input_workers): + """Initialize the `MultiWorkerDataIterator` object. + + Args: + iterators: a list of worker, iterator pairs. 
+ input_workers: an `InputWorkers` object. + + Raises: + ValueError: if iterators and input_workers are not compatible. + """ + assert isinstance(input_workers, InputWorkers) + workers = tuple(d for d, _ in iterators) + if workers != input_workers.worker_devices: + raise ValueError("iterators and input_workers are not compatible. " + "iterator workers: %r input_workers devices: %r" % + (workers, input_workers.worker_devices)) + self._iterators = tuple(i for _, i in iterators) + self._input_workers = input_workers + + @property + def initializer(self): + return control_flow_ops.group( + tuple(iterator.initializer for iterator in self._iterators)) + + def get_iterator(self, worker): + for i, w in enumerate(self._input_workers.worker_devices): + if worker == w: + return self._iterators[i] + return None + + @property + def output_shapes(self): + return self._iterators[0].output_shapes + + @property + def output_types(self): + return self._iterators[0].output_types + + def get_next(self, name=None): + """Scatter the input across hosts and devices.""" + replicas = [] + for worker, iterator in zip(self._input_workers.worker_devices, + self._iterators): + if name is not None: + d = tf_device.DeviceSpec.from_string(worker) + new_name = "%s_%s_%d" % (name, d.job, d.task) + else: + new_name = None + with ops.device(worker): + data_per_worker = iterator.get_next_as_list(name=new_name) + # Append to replicas to get a flat list of values indexed by replica. + replicas.extend(data_per_worker) + + return values.regroup(self._input_workers.device_map, replicas) + + +class MultiWorkerDataset(object): + """Like a `tf.data.Dataset` that distributes data to different workers. + + Each worker gets one shard of the input dataset. This currently does not work + in eager mode. + """ + + def __init__(self, dataset_fn, input_workers, prefetch_on_device=None, + auto_shard=False): + """Initialize the MultiWorkerDataset object. 
+ + Args: + dataset_fn: a function or a list of functions that returns a + `tf.data.Dataset`. + input_workers: an `InputWorkers` object. + prefetch_on_device: whether to prefetch to devices. + auto_shard: whether to auto-shard the dataset. + """ + assert isinstance(input_workers, InputWorkers) + if isinstance(dataset_fn, (list, tuple)): + if len(dataset_fn) != input_workers.num_workers: + raise ValueError("If `dataset_fn` is a list, it must have one entry " + "per worker") + # TODO(rohanj): b/120673685 to track re-enabling auto sharding. + if auto_shard: + raise ValueError("Currently autosharding is not supported.") + self._input_workers = input_workers + self._datasets = [] + # TODO(yuefengz, priyag): support different set of jobs for input + # processing. + for i, worker in enumerate(input_workers.worker_devices): + with ops.device(worker): + if isinstance(dataset_fn, (list, tuple)): + worker_input = dataset_fn[i]() + else: + worker_input = dataset_fn() + dataset = PerReplicaDataset(worker_input, input_workers, i, + prefetch_on_device=prefetch_on_device) + self._datasets.append((worker, dataset)) + + def make_one_shot_iterator(self): + iterators = [] + for worker, dataset in self._datasets: + with ops.device(worker): + iterators.append((worker, dataset_ops.make_one_shot_iterator(dataset))) + return MultiWorkerDataIterator(iterators, self._input_workers) + + def make_initializable_iterator(self): + iterators = [] + for worker, dataset in self._datasets: + with ops.device(worker): + iterators.append( + (worker, dataset_ops.make_initializable_iterator(dataset))) + return MultiWorkerDataIterator(iterators, self._input_workers) + + +class InputIterator(object): + """An input iterator, intended to be passed to `DistributionStrategy.run`.""" + + def get_next(self): + """Returns the next inputs for all replicas.""" + raise NotImplementedError("must be implemented in descendants") + + def initialize(self): + """Initialize the underlying input dataset, when applicable. 
+ + In eager mode, this will create a new iterator and return it. + In graph mode, this will initialize the same underlying iterator(s). + + Users are required to call this if + - This iterator was returned from a call to `make_input_fn_iterator` with an + input function that returns a dataset. + - Or this iterator was returned from a call to `make_dataset_iterator`. + + Returns: + A list of initialization ops to be executed. + """ + raise NotImplementedError("must be implemented in descendants") + + +class InputIteratorImpl(InputIterator): + """Common implementation for all input iterators.""" + + def __init__(self, input_workers, iterators): + assert isinstance(input_workers, InputWorkers) + if not input_workers.worker_devices: + raise ValueError("Should have at least one worker for input iterator.") + + self._iterators = iterators + self._input_workers = input_workers + + def get_next(self, name=None): + """Returns the next input from the iterator for all replicas.""" + replicas = [] + for i, worker in enumerate(self._input_workers.worker_devices): + if name is not None: + d = tf_device.DeviceSpec.from_string(worker) + new_name = "%s_%s_%d" % (name, d.job, d.task) + else: + new_name = None + with ops.device(worker): + # Make `replicas` a flat list of values across all replicas. + replicas.extend(self._iterators[i].get_next_as_list(new_name)) + + return values.regroup(self._input_workers.device_map, replicas) + + def initialize(self): + """Initialize underlying iterators. + + Returns: + A list of any initializer ops that should be run. + """ + init_ops = [] + for it in self._iterators: + init_ops.extend(it.initialize()) + return init_ops + + # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. + @property + def output_classes(self): + return self._iterators[0].output_classes + + # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
+ @property + def output_shapes(self): + return self._iterators[0].output_shapes + + # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. + @property + def output_types(self): + return self._iterators[0].output_types + + # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. + def get_iterator(self, worker): + for i, w in enumerate(self._input_workers.worker_devices): + if worker == w: + return self._iterators[i] + return None + + +class InputFunctionIterator(InputIteratorImpl): + """Iterator created from input function.""" + + def __init__(self, input_fn, input_workers, input_contexts): + """Make an iterator for input provided via an input function. + + Currently implements PER_WORKER mode, in which the `input_fn` is called + once on each worker. + + TODO(priyag): Add other replication modes. + TODO(priyag): Allow taking input function that returns a callable that + returns nest of tensors. + + Args: + input_fn: Input function that returns a `tf.data.Dataset` object. + input_workers: an `InputWorkers` object. + input_contexts: A list of `InputContext` instances to be passed to call(s) + to `input_fn`. Length and order should match worker order in + `worker_device_pairs`. 
+ """ + assert isinstance(input_workers, InputWorkers) + if input_workers.num_workers != len(input_contexts): + raise ValueError( + "Number of input workers (%d) is not same as number of " + "input_contexts (%d)" % + (input_workers.num_workers, len(input_contexts))) + + iterators = [] + for i, ctx in enumerate(input_contexts): + worker = input_workers.worker_devices[i] + with ops.device(worker): + result = input_fn(ctx) + if not isinstance(result, dataset_ops.DatasetV2): + raise ValueError("input_fn must return a tf.data.Dataset.") + devices = input_workers.compute_devices_for_worker(i) + iterator = _SingleWorkerDatasetIterator(result, worker, devices) + iterators.append(iterator) + + super(InputFunctionIterator, self).__init__(input_workers, iterators) + + +class DatasetIterator(InputIteratorImpl): + """Iterator created from input dataset.""" + + def __init__(self, dataset, input_workers, split_batch_by=None): + """Make an iterator for the dataset on given devices. + + If `split_batch_by` is not None, we "split" each batch of the + dataset by `split_batch_by` value. To achieve this, we first unbatch the + input dataset and then rebatch it with the per replica batch size that is + calculated using `global_batch_size // split_batch_by`. + The currently supported datasets are as follows: + `dataset.batch()` is the last operation on the dataset OR + `dataset.apply(map_and_batch)` is the last operation on the dataset OR + `dataset.batch().prefetch()` are the last 2 operations on the dataset OR + `dataset.apply(map_and_batch).prefetch()` are the last 2 operations. + + TODO(priyag): Support multi worker / host cases properly by cloning + and sharding the dataset on each worker. Current setup will only work in + some cases, such as in-graph multi worker GPU case. If the input pipeline + has random shuffling (with a different seed on each worker), each worker + will see random input from the same overall dataset in each step. 
Otherwise, + each worker will see the same input in each step. + + Args: + dataset: `tf.data.Dataset` that will be used as the input source. + input_workers: an `InputWorkers` object. + split_batch_by: Optional integer. If present, we "split" each batch of the + dataset by `split_batch_by` value. + """ + assert isinstance(input_workers, InputWorkers) + if split_batch_by: + dataset = _split_dataset_batch(dataset, split_batch_by) + + iterators = [] + for i, worker in enumerate(input_workers.worker_devices): + with ops.device(worker): + worker_devices = input_workers.compute_devices_for_worker(i) + cloned_dataset = dataset + if not context.executing_eagerly(): + cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access + iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker, + worker_devices) + iterators.append(iterator) + + super(DatasetIterator, self).__init__(input_workers, iterators) + + +class _SingleWorkerDatasetIterator(object): + """Iterator for a single `tf.data.Dataset`.""" + + def __init__(self, dataset, worker, devices): + """Create iterator for the `dataset` to fetch data to worker's `devices` . + + `MultiDeviceIterator` is used to prefetch input to the devices on the + given worker. + + Args: + dataset: A `tf.data.Dataset` instance. + worker: Worker on which ops should be created. + devices: Distribute data from `dataset` to these devices. + """ + self._dataset = dataset + self._worker = worker + self._devices = devices + self._make_iterator() + + def _make_iterator(self): + """Make appropriate iterator on the dataset.""" + with ops.device(self._worker): + self._iterator = multi_device_iterator_ops.MultiDeviceIterator( + self._dataset, self._devices) + + def get_next_as_list(self, name=None): + """Get next element from the underlying iterator.""" + del name + with ops.device(self._worker): + data_list = self._iterator.get_next() + return data_list + + def initialize(self): + """Initialze underlying iterator. 
+ + In eager execution, this simply recreates the underlying iterator. + In graph execution, it returns the initializer ops for the underlying + iterator. + + Returns: + A list of any initializer ops that should be run. + """ + if context.executing_eagerly(): + self._make_iterator() + return [] + else: + return [self._iterator.initializer] + + @property + def output_classes(self): + return self._iterator.output_classes + + @property + def output_shapes(self): + return self._iterator.output_shapes + + @property + def output_types(self): + return self._iterator.output_types + + +def _split_dataset_batch(dataset, split_batch_by): + """Divide a batch-ed dataset's batches into smaller batches.""" + # TODO(sourabhbajaj): Remove this in lieu of distributed datasets + # pylint: disable=protected-access + def _get_batch_dataset(d): + """Get the underlying batch dataset from the dataset object.""" + if isinstance(d, dataset_ops.DatasetV1Adapter): + d = d._dataset + + if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)): + return d + elif isinstance(d, dataset_ops.PrefetchDataset): + return _get_batch_dataset(d._input_dataset) + raise ValueError( + "Unable to get batched dataset from the input dataset. `batch` " + "`map_and_batch` need to be the last operations on the dataset. 
" + "The batch operations can be followed by a prefetch.") + + batched_dataset = _get_batch_dataset(dataset) + if isinstance(batched_dataset, dataset_ops.BatchDataset): + batch_size = batched_dataset._batch_size + drop_remainder = batched_dataset._drop_remainder + elif isinstance(batched_dataset, batching._MapAndBatchDataset): + batch_size = batched_dataset._batch_size_t + drop_remainder = batched_dataset._drop_remainder_t + + prefetch_buffer = None + if isinstance(dataset, dataset_ops.PrefetchDataset): + prefetch_buffer = dataset._buffer_size + elif (isinstance(dataset, dataset_ops.DatasetV1Adapter) + and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)): + prefetch_buffer = dataset._dataset._buffer_size + # pylint: enable=protected-access + + if tensor_util.is_tensor(batch_size): + batch_size = tensor_util.constant_value(batch_size) + + if tensor_util.is_tensor(drop_remainder): + drop_remainder = tensor_util.constant_value(drop_remainder) + + if batch_size % split_batch_by: + raise ValueError( + "Batch size %s cannot be sharded evenly across replicas %s" % ( + batch_size, split_batch_by)) + new_batch_size = batch_size // split_batch_by + + dataset = dataset.apply(batching.unbatch()) + dataset = dataset.batch(new_batch_size, drop_remainder=drop_remainder) + if prefetch_buffer is not None: + dataset = dataset.prefetch(prefetch_buffer) + return dataset + + +class MultiStepContext(object): + """A context object that can be used to capture things when running steps. + + This context object is useful when running multiple steps at a time using the + `experimental_run_steps_on_iterator` API. For e.g. it allows the user's step + function to specify which outputs to emit at what frequency. Currently it + supports capturing output from the last step, as well as capturing non tensor + outputs. In the future it will be augmented to support other use cases such + as output each N steps. + """ + + def __init__(self): + """Initialize an output context. 
+ + Returns: + A context object. + """ + self._last_step_outputs = {} + self._last_step_outputs_reduce_ops = {} + self._non_tensor_outputs = {} + + @property + def last_step_outputs(self): + """A dictionary consisting of outputs to be captured on last step. + + Keys in the dictionary are names of tensors to be captured, as specified + when `set_last_step_output` is called. + Values in the dictionary are the tensors themselves. If + `set_last_step_output` was called with a `reduce_op` for this output, + then the value is the reduced value. + + Returns: + A dictionary with last step outputs. + """ + return self._last_step_outputs + + def _set_last_step_outputs(self, outputs): + """Replace the entire dictionary of last step outputs.""" + if not isinstance(outputs, dict): + raise ValueError("Need a dictionary to set last_step_outputs.") + self._last_step_outputs = outputs + + def set_last_step_output(self, name, output, reduce_op=None): + """Set `output` with `name` to be outputted from the last step. + + Args: + name: String, name to identify the output. Doesn't need to match tensor + name. + output: The tensors that should be outputted with `name`. See below for + actual types supported. + reduce_op: Reduction method to use to reduce outputs from multiple + replicas. Required if `set_last_step_output` is called in a replica + context. Optional in cross_replica_context. + When present, the outputs from all the replicas are reduced using the + current distribution strategy's `reduce` method. Hence, the type of + `output` must be what's supported by the corresponding `reduce` method. + For e.g. if using MirroredStrategy and reduction is set, output + must be a `PerReplica` value. + The reduce method is also recorded in a dictionary + `_last_step_outputs_reduce_ops` for later interpreting of the + outputs as already reduced or not. 
+ """ + if distribution_strategy_context.in_cross_replica_context(): + self._last_step_outputs_reduce_ops[name] = reduce_op + if reduce_op is None: + self._last_step_outputs[name] = output + else: + distribution = distribution_strategy_context.get_distribution_strategy() + self._last_step_outputs[name] = distribution.reduce(reduce_op, output) + else: + assert reduce_op is not None + def merge_fn(distribution, value): + self._last_step_outputs[name] = distribution.reduce(reduce_op, value) + # Setting this inside the `merge_fn` because all replicas share the same + # context object, so it's more robust to set it only once (even if all + # the replicas are trying to set the same value). + self._last_step_outputs_reduce_ops[name] = reduce_op + + distribution_strategy_context.get_replica_context().merge_call( + merge_fn, args=(output,)) + + @property + def non_tensor_outputs(self): + """A dictionary consisting of any non tensor outputs to be captured.""" + return self._non_tensor_outputs + + def set_non_tensor_output(self, name, output): + """Set `output` with `name` to be captured as a non tensor output.""" + if distribution_strategy_context.in_cross_replica_context(): + self._non_tensor_outputs[name] = output + else: + def merge_fn(distribution, value): + # NOTE(priyag): For non tensor outputs, we simply return all the values + # in a list as reduction doesn't make sense on non tensors. 
+ self._non_tensor_outputs[name] = distribution.unwrap(value) + distribution_strategy_context.get_replica_context().merge_call( + merge_fn, args=(output,)) diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 601eafbb5ea..37b493d0f70 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -27,6 +27,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import shared_variable_creator @@ -456,7 +457,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): "No duplicates allowed in `devices` argument: %s" % devices) # TODO(josh11b): Require at least 2 devices? 
self._device_map = values.ReplicaDeviceMap(devices) - self._input_workers = values.InputWorkers(self._device_map) + self._input_workers = input_lib.InputWorkers(self._device_map) self._inferred_cross_device_ops = cross_device_ops_lib.choose_the_best( devices) @@ -489,7 +490,8 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): self._default_device = workers[0] self._device_map = values.ReplicaDeviceMap(devices) - self._input_workers = values.InputWorkers(self._device_map, worker_devices) + self._input_workers = input_lib.InputWorkers( + self._device_map, worker_devices) self._inferred_cross_device_ops = cross_device_ops_lib.MultiWorkerAllReduce( workers, _infer_num_gpus_per_worker(devices)) @@ -543,16 +545,16 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): def _distribute_dataset(self, dataset_fn): if self._local_mode: worker_index = 0 - return values.PerReplicaDataset( + return input_lib.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._input_workers, worker_index) else: - return values.MultiWorkerDataset( + return input_lib.MultiWorkerDataset( functools.partial(self._call_dataset_fn, dataset_fn), self._input_workers, auto_shard=False) def _make_dataset_iterator(self, dataset): - return values.DatasetIterator( + return input_lib.DatasetIterator( dataset, self._input_workers, self._num_replicas_in_sync) def _make_input_fn_iterator( @@ -566,7 +568,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): num_input_pipelines=num_workers, input_pipeline_id=i, num_replicas_in_sync=self._num_replicas_in_sync)) - return values.InputFunctionIterator( + return input_lib.InputFunctionIterator( input_fn, self._input_workers, input_contexts) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 
@@ -576,7 +578,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) - ctx = values.MultiStepContext() + ctx = input_lib.MultiStepContext() def body(i, *args): """A wrapper around `fn` to create the while loop body.""" del args diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 1f5077a75ae..a9dcabdab60 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -23,17 +23,12 @@ import contextlib import weakref import six -from tensorflow.python.data.experimental.ops import batching -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import multi_device_iterator_ops from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context -from tensorflow.python.distribute import input_ops from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import tape -from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -1409,679 +1404,6 @@ def update_regroup(extended, device_map, updates, group): return nest.pack_sequence_as(regrouped, grouped_flat) -class InputWorkers(object): - """A 1-to-many mapping from input worker devices to compute devices.""" - - def __init__(self, device_map, worker_device_pairs=None, logical_device=0): - """Initialize an `InputWorkers` object. - - Args: - device_map: A `DeviceMap` with the computation devices fed by the - input workers. - worker_device_pairs: A sequence of pairs: - `(input device, a tuple of compute devices fed by that input device)`. - logical_device: The logical device of `device_map` to feed. 
- """ - self._device_map = device_map - self._logical_device = logical_device - if worker_device_pairs is None: - worker_device_pairs = (( - device_util.canonicalize("/device:CPU:0"), - device_map.logical_to_actual_devices(logical_device)),) - self._input_worker_devices = tuple(d for d, _ in worker_device_pairs) - self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f) - for _, f in worker_device_pairs) - flattened = tuple(d for l in self._fed_devices for d in l) - assert (flattened == - device_map.logical_to_actual_devices(logical_device)), ( - "flattened: %s logical device %d: %s" % - (flattened, logical_device, - device_map.logical_to_actual_devices(logical_device))) - - @property - def device_map(self): - return self._device_map - - @property - def logical_device(self): - return self._logical_device - - @property - def num_workers(self): - return len(self._input_worker_devices) - - @property - def worker_devices(self): - return self._input_worker_devices - - def compute_devices_for_worker(self, worker_index): - return self._fed_devices[worker_index] - - def __repr__(self): - devices = self.worker_devices - debug_repr = ",\n".join(" %d %s: %s" % - (i, devices[i], self._fed_devices[i]) - for i in range(len(devices))) - return "%s:{\n%s\n device_map: %s}" % ( - self.__class__.__name__, debug_repr, self._device_map) - - -class PerReplicaDataIterator(object): - """An iterator (like `tf.data.Iterator`) into a `PerReplicaDataset`.""" - - def __init__(self, iterator, input_workers, worker_index, prefetch_on_device): - assert isinstance(input_workers, InputWorkers) - self._iterator = iterator - self._input_workers = input_workers - self._worker_index = worker_index - self._prefetch_on_device = prefetch_on_device - - @property - def initializer(self): - return self._iterator.initializer - - def get_next_as_list(self, name=None): - """Scatter the input across devices.""" - if self._prefetch_on_device: - data_list = self._iterator.get_next() - else: - batch 
= self._iterator.get_next(name=name) - data_list = [] - def get_ith(i): - return lambda x: x[i] - - devices = self._input_workers.compute_devices_for_worker( - self._worker_index) - for i, d in enumerate(devices): - v = nest.map_structure(get_ith(i), batch) - if context.executing_eagerly(): - with ops.device(d): - v = nest.map_structure(array_ops.identity, v) - data_list.append(v) - - return data_list - - def get_next(self, name=None): - assert self._input_workers.num_workers == 1 - data_list = self.get_next_as_list(name) - return regroup(self._input_workers.device_map, data_list) - - @property - def output_classes(self): - return self._iterator.output_classes - - @property - def output_shapes(self): - return self._iterator.output_shapes - - @property - def output_types(self): - return self._iterator.output_types - - -class PerReplicaDataset(object): - """Like `tf.data.Dataset` split devices, producing `PerReplica` data.""" - - def __init__(self, dataset, input_workers, worker_index, - prefetch_on_device=None): - assert isinstance(input_workers, InputWorkers) - assert worker_index is not None - assert worker_index is not True - assert worker_index is not False - self._input_workers = input_workers - self._worker_index = worker_index - - # Default to using prefetching, unless specified. - self._prefetch_on_device = prefetch_on_device - if self._prefetch_on_device is None: - self._prefetch_on_device = True - - self._dataset = dataset - if not self._prefetch_on_device: - # TODO(priyag): If dropping remainder is not appropriate, find another - # approach to distributing the dataset when not possible to divide evenly. - # Possibly not an issue when we start using PartitionedDataset. 
- num_replicas = len( - self._input_workers.compute_devices_for_worker(self._worker_index)) - self._dataset = self._dataset.batch(num_replicas, drop_remainder=True) - else: - self._replica_devices = self._input_workers.compute_devices_for_worker( - self._worker_index) - - def make_one_shot_iterator(self): - """Get a one time use iterator for the distributed PerReplicaDataset.""" - # Graph mode with one shot iterator is disabled. - if not context.executing_eagerly(): - raise ValueError("Cannot create a one shot iterator. Please use " - "`make_initializable_iterator()` instead.") - if self._prefetch_on_device: - dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._replica_devices) - else: - dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset) - return PerReplicaDataIterator( - dataset_iterator, - self._input_workers, - self._worker_index, - prefetch_on_device=self._prefetch_on_device) - - def make_initializable_iterator(self): - """Get an initializable iterator for the distributed PerReplicaDataset.""" - # Eager mode generates already initialized iterators. Hence we cannot create - # an initializable iterator. - if context.executing_eagerly(): - raise ValueError("Cannot create initializable iterator in Eager mode. " - "Please use `make_one_shot_iterator` instead.") - if self._prefetch_on_device: - dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._replica_devices) - else: - dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset) - return PerReplicaDataIterator( - dataset_iterator, self._input_workers, self._worker_index, - prefetch_on_device=self._prefetch_on_device) - - -class MultiWorkerDataIterator(object): - """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`.""" - - def __init__(self, iterators, input_workers): - """Initialize the `MultiWorkerDataIterator` object. - - Args: - iterators: a list of worker, iterator pairs. 
- input_workers: an `InputWorkers` object. - - Raises: - ValueError: if iterators and input_workers are not compatible. - """ - assert isinstance(input_workers, InputWorkers) - workers = tuple(d for d, _ in iterators) - if workers != input_workers.worker_devices: - raise ValueError("iterators and input_workers are not compatible. " - "iterator workers: %r input_workers devices: %r" % - (workers, input_workers.worker_devices)) - self._iterators = tuple(i for _, i in iterators) - self._input_workers = input_workers - - @property - def initializer(self): - return control_flow_ops.group( - tuple(iterator.initializer for iterator in self._iterators)) - - def get_iterator(self, worker): - for i, w in enumerate(self._input_workers.worker_devices): - if worker == w: - return self._iterators[i] - return None - - @property - def output_shapes(self): - return self._iterators[0].output_shapes - - @property - def output_types(self): - return self._iterators[0].output_types - - def get_next(self, name=None): - """Scatter the input across hosts and devices.""" - replicas = [] - for worker, iterator in zip(self._input_workers.worker_devices, - self._iterators): - if name is not None: - d = tf_device.DeviceSpec.from_string(worker) - new_name = "%s_%s_%d" % (name, d.job, d.task) - else: - new_name = None - with ops.device(worker): - data_per_worker = iterator.get_next_as_list(name=new_name) - # Append to replicas to get a flat list of values indexed by replica. - replicas.extend(data_per_worker) - - return regroup(self._input_workers.device_map, replicas) - - -class MultiWorkerDataset(object): - """Like a `tf.data.Dataset` that distributes data to different workers. - - Each worker gets one shard of the input dataset. This currently does not work - in eager mode. - """ - - def __init__(self, dataset_fn, input_workers, prefetch_on_device=None, - auto_shard=False): - """Initialize the MultiWorkerDataset object. 
- - Args: - dataset_fn: a function or a list of functions that returns a - `tf.data.Dataset`. - input_workers: an `InputWorkers` object. - prefetch_on_device: whether to prefetch to devices. - auto_shard: whether to auto-shard the dataset. - """ - assert isinstance(input_workers, InputWorkers) - if isinstance(dataset_fn, (list, tuple)): - if len(dataset_fn) != input_workers.num_workers: - raise ValueError("If `dataset_fn` is a list, it must have one entry " - "per worker") - # TODO(rohanj): b/120673685 to track re-enabling auto sharding. - if auto_shard: - raise ValueError("Currently autosharding is not supported.") - self._input_workers = input_workers - self._datasets = [] - # TODO(yuefengz, priyag): support different set of jobs for input - # processing. - for i, worker in enumerate(input_workers.worker_devices): - with ops.device(worker): - if isinstance(dataset_fn, (list, tuple)): - worker_input = dataset_fn[i]() - else: - worker_input = dataset_fn() - dataset = PerReplicaDataset(worker_input, input_workers, i, - prefetch_on_device=prefetch_on_device) - self._datasets.append((worker, dataset)) - - def make_one_shot_iterator(self): - iterators = [] - for worker, dataset in self._datasets: - with ops.device(worker): - iterators.append((worker, dataset_ops.make_one_shot_iterator(dataset))) - return MultiWorkerDataIterator(iterators, self._input_workers) - - def make_initializable_iterator(self): - iterators = [] - for worker, dataset in self._datasets: - with ops.device(worker): - iterators.append( - (worker, dataset_ops.make_initializable_iterator(dataset))) - return MultiWorkerDataIterator(iterators, self._input_workers) - - -class InputIterator(object): - """An input iterator, intended to be passed to `DistributionStrategy.run`.""" - - def get_next(self): - """Returns the next inputs for all replicas.""" - raise NotImplementedError("must be implemented in descendants") - - def initialize(self): - """Initialize the underlying input dataset, when applicable. 
- - In eager mode, this will create a new iterator and return it. - In graph mode, this will initialize the same underlying iterator(s). - - Users are required to call this if - - This iterator was returned from a call to `make_input_fn_iterator` with an - input function that returns a dataset. - - Or this iterator was returned from a call to `make_dataset_iterator`. - - Returns: - A list of initialization ops to be executed. - """ - raise NotImplementedError("must be implemented in descendants") - - -class InputIteratorImpl(InputIterator): - """Common implementation for all input iterators.""" - - def __init__(self, input_workers, iterators): - assert isinstance(input_workers, InputWorkers) - if not input_workers.worker_devices: - raise ValueError("Should have at least one worker for input iterator.") - - self._iterators = iterators - self._input_workers = input_workers - - def get_next(self, name=None): - """Returns the next input from the iterator for all replicas.""" - replicas = [] - for i, worker in enumerate(self._input_workers.worker_devices): - if name is not None: - d = tf_device.DeviceSpec.from_string(worker) - new_name = "%s_%s_%d" % (name, d.job, d.task) - else: - new_name = None - with ops.device(worker): - # Make `replicas` a flat list of values across all replicas. - replicas.extend(self._iterators[i].get_next_as_list(new_name)) - - return regroup(self._input_workers.device_map, replicas) - - def initialize(self): - """Initialze underlying iterators. - - Returns: - A list of any initializer ops that should be run. - """ - init_ops = [] - for it in self._iterators: - init_ops.extend(it.initialize()) - return init_ops - - # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. - @property - def output_classes(self): - return self._iterators[0].output_classes - - # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. 
- @property - def output_shapes(self): - return self._iterators[0].output_shapes - - # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. - @property - def output_types(self): - return self._iterators[0].output_types - - # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. - def get_iterator(self, worker): - for i, w in enumerate(self._input_workers.worker_devices): - if worker == w: - return self._iterators[i] - return None - - -class InputFunctionIterator(InputIteratorImpl): - """Iterator created from input function.""" - - def __init__(self, input_fn, input_workers, input_contexts): - """Make an iterator for input provided via an input function. - - Currently implements PER_WORKER mode, in which the `input_fn` is called - once on each worker. - - TODO(priyag): Add other replication modes. - TODO(priyag): Allow taking input function that returns a callable that - returns nest of tensors. - - Args: - input_fn: Input function that returns a `tf.data.Dataset` object. - input_workers: an `InputWorkers` object. - input_contexts: A list of `InputContext` instances to be passed to call(s) - to `input_fn`. Length and order should match worker order in - `worker_device_pairs`. 
- """ - assert isinstance(input_workers, InputWorkers) - if input_workers.num_workers != len(input_contexts): - raise ValueError( - "Number of input workers (%d) is not same as number of " - "input_contexts (%d)" % - (input_workers.num_workers, len(input_contexts))) - - iterators = [] - for i, ctx in enumerate(input_contexts): - worker = input_workers.worker_devices[i] - with ops.device(worker): - result = input_fn(ctx) - if not isinstance(result, dataset_ops.DatasetV2): - raise ValueError("input_fn must return a tf.data.Dataset.") - devices = input_workers.compute_devices_for_worker(i) - iterator = _SingleWorkerDatasetIterator(result, worker, devices) - iterators.append(iterator) - - super(InputFunctionIterator, self).__init__(input_workers, iterators) - - -class DatasetIterator(InputIteratorImpl): - """Iterator created from input dataset.""" - - def __init__(self, dataset, input_workers, split_batch_by=None): - """Make an iterator for the dataset on given devices. - - If `split_batch_by` is not None, we "split" each batch of the - dataset by `split_batch_by` value. To achieve this, we first unbatch the - input dataset and then rebatch it with the per replica batch size that is - calculated using `global_batch_size // split_batch_by`. - The currently supported datasets are as follows: - `dataset.batch()` is the last operation on the dataset OR - `dataset.apply(map_and_batch)` is the last operation on the dataset OR - `dataset.batch().prefetch()` are the last 2 operations on the dataset OR - `dataset.apply(map_and_batch).prefetch()` are the last 2 operations. - - TODO(priyag): Support multi worker / host cases properly by cloning - and sharding the dataset on each worker. Current setup will only work in - some cases, such as in-graph multi worker GPU case. If the input pipeline - has random shuffling (with a different seed on each worker), each worker - will see random input from the same overall dataset in each step. 
Otherwise, - each worker will see the same input in each step. - - Args: - dataset: `tf.data.Dataset` that will be used as the input source. - input_workers: an `InputWorkers` object. - split_batch_by: Optional integer. If present, we "split" each batch of the - dataset by `split_batch_by` value. - """ - assert isinstance(input_workers, InputWorkers) - if split_batch_by: - dataset = _split_dataset_batch(dataset, split_batch_by) - - iterators = [] - for i, worker in enumerate(input_workers.worker_devices): - with ops.device(worker): - worker_devices = input_workers.compute_devices_for_worker(i) - cloned_dataset = dataset - if not context.executing_eagerly(): - cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access - iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker, - worker_devices) - iterators.append(iterator) - - super(DatasetIterator, self).__init__(input_workers, iterators) - - -class _SingleWorkerDatasetIterator(object): - """Iterator for a single `tf.data.Dataset`.""" - - def __init__(self, dataset, worker, devices): - """Create iterator for the `dataset` to fetch data to worker's `devices` . - - `MultiDeviceIterator` is used to prefetch input to the devices on the - given worker. - - Args: - dataset: A `tf.data.Dataset` instance. - worker: Worker on which ops should be created. - devices: Distribute data from `dataset` to these devices. - """ - self._dataset = dataset - self._worker = worker - self._devices = devices - self._make_iterator() - - def _make_iterator(self): - """Make appropriate iterator on the dataset.""" - with ops.device(self._worker): - self._iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._devices) - - def get_next_as_list(self, name=None): - """Get next element from the underlying iterator.""" - del name - with ops.device(self._worker): - data_list = self._iterator.get_next() - return data_list - - def initialize(self): - """Initialze underlying iterator. 
- - In eager execution, this simply recreates the underlying iterator. - In graph execution, it returns the initializer ops for the underlying - iterator. - - Returns: - A list of any initializer ops that should be run. - """ - if context.executing_eagerly(): - self._make_iterator() - return [] - else: - return [self._iterator.initializer] - - @property - def output_classes(self): - return self._iterator.output_classes - - @property - def output_shapes(self): - return self._iterator.output_shapes - - @property - def output_types(self): - return self._iterator.output_types - - -def _split_dataset_batch(dataset, split_batch_by): - """Divide a batch-ed dataset's batches into smaller batches.""" - # TODO(sourabhbajaj): Remove this in lieu of distributed datasets - # pylint: disable=protected-access - def _get_batch_dataset(d): - """Get the underlying batch dataset from the dataset object.""" - if isinstance(d, dataset_ops.DatasetV1Adapter): - d = d._dataset - - if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)): - return d - elif isinstance(d, dataset_ops.PrefetchDataset): - return _get_batch_dataset(d._input_dataset) - raise ValueError( - "Unable to get batched dataset from the input dataset. `batch` " - "`map_and_batch` need to be the last operations on the dataset. 
" - "The batch operations can be followed by a prefetch.") - - batched_dataset = _get_batch_dataset(dataset) - if isinstance(batched_dataset, dataset_ops.BatchDataset): - batch_size = batched_dataset._batch_size - drop_remainder = batched_dataset._drop_remainder - elif isinstance(batched_dataset, batching._MapAndBatchDataset): - batch_size = batched_dataset._batch_size_t - drop_remainder = batched_dataset._drop_remainder_t - - prefetch_buffer = None - if isinstance(dataset, dataset_ops.PrefetchDataset): - prefetch_buffer = dataset._buffer_size - elif (isinstance(dataset, dataset_ops.DatasetV1Adapter) - and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)): - prefetch_buffer = dataset._dataset._buffer_size - # pylint: enable=protected-access - - if tensor_util.is_tensor(batch_size): - batch_size = tensor_util.constant_value(batch_size) - - if tensor_util.is_tensor(drop_remainder): - drop_remainder = tensor_util.constant_value(drop_remainder) - - if batch_size % split_batch_by: - raise ValueError( - "Batch size %s cannot be sharded evenly across replicas %s" % ( - batch_size, split_batch_by)) - new_batch_size = batch_size // split_batch_by - - dataset = dataset.apply(batching.unbatch()) - dataset = dataset.batch(new_batch_size, drop_remainder=drop_remainder) - if prefetch_buffer is not None: - dataset = dataset.prefetch(prefetch_buffer) - return dataset - - -class MultiStepContext(object): - """A context object that can be used to capture things when running steps. - - This context object is useful when running multiple steps at a time using the - `experimental_run_steps_on_iterator` API. For e.g. it allows the user's step - function to specify which outputs to emit at what frequency. Currently it - supports capturing output from the last step, as well as capturing non tensor - outputs. In the future it will be augmented to support other use cases such - as output each N steps. - """ - - def __init__(self): - """Initialize an output context. 
- - Returns: - A context object. - """ - self._last_step_outputs = {} - self._last_step_outputs_reduce_ops = {} - self._non_tensor_outputs = {} - - @property - def last_step_outputs(self): - """A dictionary consisting of outputs to be captured on last step. - - Keys in the dictionary are names of tensors to be captured, as specified - when `set_last_step_output` is called. - Values in the dictionary are the tensors themselves. If - `set_last_step_output` was called with a `reduce_op` for this output, - then the value is the reduced value. - - Returns: - A dictionary with last step outputs. - """ - return self._last_step_outputs - - def _set_last_step_outputs(self, outputs): - """Replace the entire dictionary of last step outputs.""" - if not isinstance(outputs, dict): - raise ValueError("Need a dictionary to set last_step_outputs.") - self._last_step_outputs = outputs - - def set_last_step_output(self, name, output, reduce_op=None): - """Set `output` with `name` to be outputted from the last step. - - Args: - name: String, name to identify the output. Doesn't need to match tensor - name. - output: The tensors that should be outputted with `name`. See below for - actual types supported. - reduce_op: Reduction method to use to reduce outputs from multiple - replicas. Required if `set_last_step_output` is called in a replica - context. Optional in cross_replica_context. - When present, the outputs from all the replicas are reduced using the - current distribution strategy's `reduce` method. Hence, the type of - `output` must be what's supported by the corresponding `reduce` method. - For e.g. if using MirroredStrategy and reduction is set, output - must be a `PerReplica` value. - The reduce method is also recorded in a dictionary - `_last_step_outputs_reduce_ops` for later interpreting of the - outputs as already reduced or not. 
- """ - if distribution_strategy_context.in_cross_replica_context(): - self._last_step_outputs_reduce_ops[name] = reduce_op - if reduce_op is None: - self._last_step_outputs[name] = output - else: - distribution = distribution_strategy_context.get_distribution_strategy() - self._last_step_outputs[name] = distribution.reduce(reduce_op, output) - else: - assert reduce_op is not None - def merge_fn(distribution, value): - self._last_step_outputs[name] = distribution.reduce(reduce_op, value) - # Setting this inside the `merge_fn` because all replicas share the same - # context object, so it's more robust to set it only once (even if all - # the replicas are trying to set the same value). - self._last_step_outputs_reduce_ops[name] = reduce_op - - distribution_strategy_context.get_replica_context().merge_call( - merge_fn, args=(output,)) - - @property - def non_tensor_outputs(self): - """A dictionary consisting of any non tensor outputs to be captured.""" - return self._non_tensor_outputs - - def set_non_tensor_output(self, name, output): - """Set `output` with `name` to be captured as a non tensor output.""" - if distribution_strategy_context.in_cross_replica_context(): - self._non_tensor_outputs[name] = output - else: - def merge_fn(distribution, value): - # NOTE(priyag): For non tensor outputs, we simply return all the values - # in a list as reduction doesn't make sense on non tensors. - self._non_tensor_outputs[name] = distribution.unwrap(value) - distribution_strategy_context.get_replica_context().merge_call( - merge_fn, args=(output,)) - - def value_container(val): """Returns the container that this per-replica `value` belongs to.