From 0a88318ac73a3d30dfee4ab1541070cbf82fe05c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Jan 2019 14:04:07 -0800 Subject: [PATCH] Add `tf.distribute.Strategy.experimental_make_numpy_iterator()` function. PiperOrigin-RevId: 228584021 --- tensorflow/contrib/distribute/python/BUILD | 4 + .../python/collective_all_reduce_strategy.py | 8 +- .../collective_all_reduce_strategy_test.py | 7 ++ .../contrib/distribute/python/keras_test.py | 58 +++------ .../distribute/python/mirrored_strategy.py | 30 +++++ .../python/mirrored_strategy_multigpu_test.py | 3 + .../distribute/python/one_device_strategy.py | 9 +- .../python/one_device_strategy_test.py | 4 + .../python/parameter_server_strategy.py | 31 +++++ .../python/parameter_server_strategy_test.py | 12 ++ .../distribute/python/strategy_test_lib.py | 29 +++++ .../contrib/distribute/python/tpu_strategy.py | 9 ++ tensorflow/python/distribute/BUILD | 32 +++++ .../python/distribute/distribute_lib.py | 72 +++++++++++- .../python/distribute/mirrored_strategy.py | 10 ++ tensorflow/python/distribute/numpy_dataset.py | 97 +++++++++++++++ .../python/distribute/numpy_dataset_test.py | 44 +++++++ .../distribute/parameter_server_strategy.py | 13 +- .../engine/distributed_training_utils.py | 111 +----------------- tensorflow/python/keras/engine/training.py | 60 +++++----- .../keras/engine/training_distributed.py | 6 - tensorflow/python/training/optimizer.py | 13 +- ...orflow.distribute.-mirrored-strategy.pbtxt | 4 + ...orflow.distribute.-strategy-extended.pbtxt | 4 + .../v1/tensorflow.distribute.-strategy.pbtxt | 4 + ...orflow.distribute.-mirrored-strategy.pbtxt | 4 + ...orflow.distribute.-strategy-extended.pbtxt | 4 + .../v2/tensorflow.distribute.-strategy.pbtxt | 4 + 28 files changed, 489 insertions(+), 197 deletions(-) create mode 100644 tensorflow/python/distribute/numpy_dataset.py create mode 100644 tensorflow/python/distribute/numpy_dataset_test.py diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 7ee12812af3..d4758d7518f 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -131,6 +131,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", @@ -153,6 +154,7 @@ py_library( "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], @@ -178,6 +180,7 @@ py_library( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//third_party/py/numpy", ], ) @@ -303,6 +306,7 @@ py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", ], diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index f6361cb6e89..39756b32c57 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ 
b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -28,6 +28,7 @@ from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import ops @@ -86,6 +87,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): else: local_devices = ("/device:CPU:0",) self._worker_device = device_util.canonicalize("/device:CPU:0") + self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) @@ -121,6 +123,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): task_id) self._worker_device = "/job:%s/task:%d" % (task_type, task_id) + self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) if num_gpus_per_worker: local_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) @@ -157,6 +160,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): if colocate_with is None: device_map = self._device_map logical_device = 0 # TODO(josh11b): Get logical device from scope here. + elif isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return next_creator(*args, **kwargs) else: device_map = colocate_with.device_map logical_device = colocate_with.logical_device @@ -347,4 +353,4 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): - return False + return True diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 4c8c01a216a..c3e9c55e96c 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -466,6 +466,13 @@ class LocalCollectiveAllReduceStrategy( with self.cached_session(config=config, target=target): self._test_all_reduce_mean_gradient_tape(distribution) + def testNumpyIterator(self): + num_gpus = 2 + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + strategy, _, _ = self._get_test_object(None, None, num_gpus) + self._test_numpy_iterator(strategy) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 1cb5fa30a3e..40916afcfaa 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -429,15 +429,6 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_for_numpy_input_combinations()) - def test_creating_var_with_numpy_arrays(self, distribution): - with self.cached_session(): - x = np.asarray(np.random.random((64, 3)), dtype=np.float32) - var_x = distributed_training_utils.get_var_for_numpy(distribution, x) - val = self.evaluate(var_x.value()) - # Verify that the numpy value is copied to the variable. 
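A hedged sketch of the `numpy_dataset.SingleDevice` colocation branch added to `_create_variable` above: when the requested colocation target is a `SingleDevice`, the creator bypasses mirroring and builds the variable directly on that one device (the input host). The function name, signature, and flow below are illustrative assumptions, not code from this patch.

from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.framework import ops

def _create_variable_sketch(next_creator, colocate_with, *args, **kwargs):
  # Sketch only: mirrors the `elif isinstance(colocate_with, SingleDevice)`
  # branch above -- pin the variable to the single input-host device
  # instead of creating one copy per replica.
  if isinstance(colocate_with, numpy_dataset.SingleDevice):
    with ops.device(colocate_with.device):
      return next_creator(*args, **kwargs)
  # ...otherwise fall through to the existing mirrored-creation path.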
- self.assertAllEqual(x, val) - @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_no_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -576,26 +567,26 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, metrics = ['mae'] model.compile(optimizer, loss, metrics=metrics) - inputs = np.zeros((64, 3), dtype=np.float32) - targets = np.zeros((64, 4), dtype=np.float32) + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) - # Call fit with validation data - model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, - validation_data=(inputs, targets)) + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) - # TODO(anjalisridhar): We need tests for when the batch size and steps are - # smaller and results in a 0 batch_size and steps value. - model.evaluate(inputs, targets) - # with steps - model.evaluate(inputs, targets, steps=2) - # with batch_size - model.evaluate(inputs, targets, batch_size=8) + # TODO(anjalisridhar): We need tests for when the batch size and steps + # are smaller and results in a 0 batch_size and steps value. + model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) - model.predict(inputs) - # with steps - model.predict(inputs, steps=2) - # with batch_size - model.predict(inputs, batch_size=8) + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): @@ -1192,21 +1183,6 @@ class TestDistributionStrategyValidation(test.TestCase, metrics = ['mae', keras.metrics.CategoricalAccuracy()] model.compile(optimizer, loss, metrics=metrics) - @combinations.generate(all_strategy_combinations_minus_default()) - def test_loop_in_scope(self, distribution): - with self.cached_session(): - with self.assertRaisesRegexp( - RuntimeError, 'should not be run inside the tf.distribute.Strategy'): - with distribution.scope(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - model.compile(optimizer, loss) - input_array = np.zeros((3, 3), dtype=np.float32) - model.predict(input_array) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index e3ab2bf19e5..5fa36fb4027 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -104,6 +104,36 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): auto_shard_dataset) super(MirroredStrategy, self).__init__(extended) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. 
+ + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `MirroredStrategy`. + + Args: + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. + """ + return super(MirroredStrategy, self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) + class MirroredExtended(CoreMirroredExtended): """Implementation of (contrib) MirroredStrategy.""" diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 9821828b2d6..59d711ae019 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -116,6 +116,9 @@ class MirroredTwoDeviceDistributionTest( self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, expected_values) + def testNumpyIterator(self, distribution): + self._test_numpy_iterator(distribution) + def testGlobalStepUpdate(self, distribution): self._test_global_step_update(distribution) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index fb470f8546f..34b0c31087d 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import values from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -50,8 +51,8 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): super(OneDeviceExtended, self).__init__(container_strategy) self._device = device self._default_device = device - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, [self._device])] + self._input_device = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(self._input_device, [self._device])] device_map = values.SingleDeviceMap(device) self._input_workers = input_lib.InputWorkers( device_map, worker_device_pairs) @@ -82,6 +83,10 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): return input_lib.InputFunctionIterator( input_fn, self._input_workers, [distribute_lib.InputContext()]) + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, self._input_device, session) + def _broadcast_to(self, tensor, destinations): del destinations return tensor diff 
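A usage sketch for the contrib variant documented above, where `batch_size` is the per-replica batch (the core `tf.distribute` version treats it as the global batch). The two-GPU setup, array shapes, and graph-mode session are illustrative assumptions; the same call shape applies to the contrib `ParameterServerStrategy` override further below.

import numpy as np
import tensorflow as tf
from tensorflow.contrib.distribute import MirroredStrategy  # contrib variant

strategy = MirroredStrategy(num_gpus=2)  # assumed two-GPU machine
features = np.random.random((64, 3)).astype(np.float32)
labels = np.random.random((64, 4)).astype(np.float32)

with strategy.scope(), tf.Session() as sess:
  # batch_size=8 is per replica; with 2 replicas the effective global batch
  # is 8 * strategy.num_replicas_in_sync == 16.
  iterator = strategy.experimental_make_numpy_iterator(
      (features, labels), batch_size=8, num_epochs=1, shuffle=None,
      session=sess)
  sess.run(iterator.initialize())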
--git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 2403dc8f125..f81466a6c75 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -59,6 +59,10 @@ class OneDeviceStrategyTest( self._test_input_fn_iterator( iterator, d.extended.worker_devices, expected_values) + @test_util.run_in_graph_and_eager_modes + def testNumpyIterator(self): + self._test_numpy_iterator(self._get_distribution_strategy()) + def testAllReduceSum(self): self._test_all_reduce_sum(self._get_distribution_strategy()) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index bb0b8eb9927..0785427c2c0 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -89,6 +89,37 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): super(ParameterServerStrategy, self).__init__( ParameterServerExtended(self, num_gpus_per_worker)) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. + + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `ParameterServerStrategy`. + + Args: + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. 
+ """ + return super(ParameterServerStrategy, + self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) + class ParameterServerExtended(CoreParameterServerExtended): """Implementation of ParameterServerStrategy.""" diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 7836687e7d6..802809e7c7e 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -893,5 +893,17 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, strategy.extended.call_for_each_replica(f) +class LocalParameterServerStrategyTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine(mode=['graph', 'eager'], + use_core_strategy=[True, False], + required_gpus=2)) + def testNumpyIterator(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + self._test_numpy_iterator(strategy) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 4fbd630cf72..7455cbd02a2 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context @@ -295,6 +297,33 @@ class DistributionTestBase(test.TestCase): global_step_values = self.evaluate(global_step_tensors) self.assertEqual((1,) * len(global_step_tensors), global_step_values) + def _test_numpy_iterator(self, strategy): + with strategy.scope(), self.cached_session() as sess: + x = np.asarray([[1, 2], [6, 12], [2, 4], + [5, 10], [3, 6], [4, 8]]) + y = np.asarray([5, 4, 3, 2, 1, 0]) + batch_size = 6 + if not strategy.extended._global_batch_size: # pylint: disable=protected-access + batch_size = batch_size // strategy.num_replicas_in_sync + i = strategy.experimental_make_numpy_iterator( + (x, y), batch_size=batch_size, num_epochs=2, shuffle=None, + session=sess) + self.evaluate(i.initialize()) + + def run_and_concatenate(strategy, i): + x, y = strategy.experimental_run(lambda z: z, i) + x, y = self.evaluate((strategy.unwrap(x), strategy.unwrap(y))) + return np.concatenate(x), np.concatenate(y) + + x_1, y_1 = run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_1) + self.assertAllEqual(y, y_1) + x_2, y_2 = run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_2) + self.assertAllEqual(y, y_2) + with self.assertRaises(errors.OutOfRangeError): + run_and_concatenate(strategy, i) + class OneDeviceDistributionTestBase(test.TestCase): """Some tests that should work with any one-device DistributionStrategy.""" diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 00360bf9960..3f89c5869ed 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -34,6 +34,7 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_li 
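A brief worked note on the `_global_batch_size` flag exercised by `_test_numpy_iterator` above (and flipped to `True` for the collective strategy earlier in this patch): when the flag is `False`, batch sizes are per replica, so the test divides its global batch of 6 by the replica count before calling the API. The replica count below is an assumed example.

# Restating the test's batching arithmetic (2 replicas assumed).
num_replicas_in_sync = 2
global_batch_size = 6                       # all six examples per step
per_replica_batch = global_batch_size // num_replicas_in_sync   # == 3
# _global_batch_size == True  -> pass 6 straight through (core semantics).
# _global_batch_size == False -> pass 3; the strategy scales it back up.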
from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver_lib @@ -303,6 +304,11 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): functools.partial(self._call_dataset_fn, dataset_fn), self._input_workers) + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, numpy_dataset.SingleDevice(self.get_host_cpu_device(0)), + session) + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. @@ -466,6 +472,9 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if colocate_with is None: device_map = self._device_map logical_device = 0 # TODO(josh11b): Get logical device from scope here. + elif isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return next_creator(*args, **kwargs) else: device_map = colocate_with.device_map logical_device = colocate_with.logical_device diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index faaa61934ac..a6a1c470b41 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -124,6 +124,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":device_util", + ":numpy_dataset", ":reduce_util", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", @@ -221,6 +222,7 @@ py_library( ":distribute_lib", ":input_lib", ":multi_worker_util", + ":numpy_dataset", ":reduce_util", ":shared_variable_creator", ":values", @@ -249,6 +251,7 @@ py_library( deps = [ ":input_lib", ":mirrored_strategy", + ":numpy_dataset", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", @@ -276,6 +279,35 @@ py_library( ], ) +py_library( + name = "numpy_dataset", + srcs = ["numpy_dataset.py"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//third_party/py/numpy", + ], +) + +py_test( + name = "numpy_dataset_test", + size = "small", + srcs = ["numpy_dataset_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":numpy_dataset", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:test", + "//third_party/py/numpy", + ], +) + py_library( name = "input_lib", srcs = ["input_lib.py"], diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 4ad8cc00b8a..31213ab6472 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -26,6 +26,7 @@ import enum from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from 
tensorflow.python.eager import context as eager_context from tensorflow.python.framework import constant_op @@ -360,7 +361,7 @@ class DistributionStrategy(object): return self._extended._distribute_dataset(dataset_fn) # pylint: disable=protected-access def make_dataset_iterator(self, dataset): - """Makes an iterator for input provided via input_dataset. + """Makes an iterator for input provided via `dataset`. Data from the given dataset will be distributed evenly across all the compute replicas. We will assume that the input dataset is batched by the @@ -418,6 +419,40 @@ class DistributionStrategy(object): return self.extended._make_input_fn_iterator( # pylint: disable=protected-access input_fn, replication_mode=replication_mode) + def experimental_make_numpy_iterator( + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. + + Args: + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. Note that lists of Numpy arrays are stacked, + as that is normal `tf.data.Dataset` behavior. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the global batch + size. It should be divisible by `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. + """ + ds = self.extended.experimental_make_numpy_dataset( + numpy_input, session=session) + if shuffle: + ds = ds.shuffle(shuffle) + if num_epochs != 1: + ds = ds.repeat(num_epochs) + # We need to use the drop_remainder argument to get a known static + # input shape which is required for TPUs. + drop_remainder = self.extended.experimental_require_static_shapes + ds = ds.batch(batch_size, drop_remainder=drop_remainder) + return self.make_dataset_iterator(ds) + def experimental_run(self, fn, input_iterator=None): """Runs ops in `fn` on each replica, with inputs from `input_iterator`. @@ -1083,6 +1118,29 @@ class DistributionStrategyExtended(object): def _make_input_fn_iterator(self, input_fn, replication_mode): raise NotImplementedError("must be implemented in descendants") + def experimental_make_numpy_dataset(self, numpy_input, session=None): + """Makes a dataset for input provided via a numpy array. + + This avoids adding `numpy_input` as a large constant in the graph, + and copies the data to the machine or machines that will be processing + the input. + + Args: + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. Note that lists of Numpy arrays are stacked, + as that is normal `tf.data.Dataset` behavior. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + A `tf.data.Dataset` representing `numpy_input`. 
+ """ + _require_cross_replica_context_extended(self) + return self._experimental_make_numpy_dataset(numpy_input, session=session) + + def _experimental_make_numpy_dataset(self, numpy_input, session): + raise NotImplementedError("must be implemented in descendants") + def broadcast_to(self, tensor, destinations): """Mirror a tensor on one device to all worker devices. @@ -1660,6 +1718,18 @@ class _DefaultDistributionExtended(DistributionStrategyExtended): replication_mode=InputReplicationMode.PER_WORKER): return input_fn(InputContext()).make_initializable_iterator() + def _experimental_make_numpy_dataset(self, numpy_input, session): + numpy_flat = nest.flatten(numpy_input) + vars_flat = tuple( + variable_scope.variable(array_ops.zeros(i.shape, i.dtype), + trainable=False, use_resource=True) + for i in numpy_flat + ) + for v, i in zip(vars_flat, numpy_flat): + numpy_dataset.init_var_from_numpy(v, i, session) + vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) + return dataset_ops.Dataset.from_tensor_slices(vars_nested) + def _broadcast_to(self, tensor, destinations): if destinations is None: return tensor diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 37b493d0f70..c0a39d4b559 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -29,6 +29,7 @@ from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import shared_variable_creator from tensorflow.python.distribute import values @@ -460,6 +461,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): self._input_workers = input_lib.InputWorkers(self._device_map) self._inferred_cross_device_ops = cross_device_ops_lib.choose_the_best( devices) + self._host_input_device = numpy_dataset.SingleDevice("/cpu:0") def _initialize_multi_worker(self, devices): """Initializes the object for multi-worker training.""" @@ -488,6 +490,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): # their ops will end up on the cpu device of its first worker, e.g. # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode. self._default_device = workers[0] + self._host_input_device = numpy_dataset.SingleDevice(workers[0]) self._device_map = values.ReplicaDeviceMap(devices) self._input_workers = input_lib.InputWorkers( @@ -501,6 +504,9 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): if colocate_with is None: device_map = self._device_map logical_device = 0 # TODO(josh11b): Get logical device from scope here. + elif isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return next_creator(*args, **kwargs) else: device_map = colocate_with.device_map logical_device = colocate_with.logical_device @@ -571,6 +577,10 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended): return input_lib.InputFunctionIterator( input_fn, self._input_workers, input_contexts) + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, self._host_input_device, session) + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 
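A hedged end-to-end sketch of the core API introduced above: `experimental_make_numpy_iterator` builds the dataset via `experimental_make_numpy_dataset`, applies shuffle/repeat, batches with `drop_remainder` taken from `experimental_require_static_shapes`, and wraps the result in `make_dataset_iterator`. The strategy choice, shapes, and graph-mode session below are assumptions; the run-and-unwrap pattern follows the `_test_numpy_iterator` helper added in this patch.

import numpy as np
import tensorflow as tf  # TF 1.x-era tf.distribute API assumed

strategy = tf.distribute.MirroredStrategy()
x = np.random.random((32, 3)).astype(np.float32)
y = np.random.random((32, 1)).astype(np.float32)

with strategy.scope(), tf.Session() as sess:
  # For the core strategies, batch_size=8 is the global batch size and
  # should be divisible by strategy.num_replicas_in_sync.
  iterator = strategy.experimental_make_numpy_iterator(
      (x, y), batch_size=8, num_epochs=1, shuffle=1024, session=sess)
  sess.run(iterator.initialize())
  per_replica_x, per_replica_y = strategy.experimental_run(
      lambda inputs: inputs, iterator)
  # One tensor per replica; unwrap to plain per-replica values.
  x_vals = sess.run(strategy.unwrap(per_replica_x))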
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, initial_loop_values=None): diff --git a/tensorflow/python/distribute/numpy_dataset.py b/tensorflow/python/distribute/numpy_dataset.py new file mode 100644 index 00000000000..5881e4cd59e --- /dev/null +++ b/tensorflow/python/distribute/numpy_dataset.py @@ -0,0 +1,97 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Code for creating a dataset out of a NumPy array.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import nest + + +def init_var_from_numpy(input_var, numpy_input, session): + """Initialize `input_var` to `numpy_input` using `session` in graph mode.""" + with ops.init_scope(): + if context.executing_eagerly(): + input_var.assign(numpy_input) + return + + assert session is not None + session.run(input_var.initializer) + + start_placeholder = array_ops.placeholder(dtypes.int64, ()) + end_placeholder = array_ops.placeholder(dtypes.int64, ()) + slice_placeholder = array_ops.placeholder(input_var.dtype) + assign_slice_op = input_var[start_placeholder:end_placeholder].assign( + slice_placeholder) + + # If each batch element is > 64 MB, then we copy each batch element + # individually. Otherwise, the slices will be < 128 MB. There might be + # padding which might mean that the slices are 128 MB even if the size of + # the tensor allocated is less than 128 MB. This formula gives slices with + # size: ceil(64 MB / byte size per batch element) bytes. Using ceil() + # guarantees we get a number >= 1. + + # Calculate the size of each batch element. + byte_size_per_batch_element = ( + np.prod(numpy_input.shape[1:]) * input_var.dtype.size) + + # Calculate number of elements we want to copy per slice. + batch_size_per_slice = int( + np.ceil((64 << 20) / byte_size_per_batch_element)) + + # Copy slices of the above size starting at 0, except the last slice will be + # smaller. 
+ start = 0 + limit = numpy_input.shape[0] + while start < limit: + end = min(start + batch_size_per_slice, limit) + session.run(assign_slice_op, feed_dict={ + start_placeholder: start, + end_placeholder: end, + slice_placeholder: numpy_input[start:end]}) + start = end + + +def one_host_numpy_dataset(numpy_input, colocate_with, session): + """Create a dataset on `colocate_with` from `numpy_input`.""" + def create_colocated_variable(next_creator, *args, **kwargs): + kwargs["colocate_with"] = colocate_with + return next_creator(*args, **kwargs) + + numpy_flat = nest.flatten(numpy_input) + with variable_scope.variable_creator_scope(create_colocated_variable): + vars_flat = tuple(variable_scope.variable(array_ops.zeros(i.shape, i.dtype), + trainable=False) + for i in numpy_flat) + for v, i in zip(vars_flat, numpy_flat): + init_var_from_numpy(v, i, session) + vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) + return dataset_ops.Dataset.from_tensor_slices(vars_nested) + + +class SingleDevice(object): + """Used with `colocate_with` to create a non-mirrored variable.""" + + def __init__(self, device): + self.device = device diff --git a/tensorflow/python/distribute/numpy_dataset_test.py b/tensorflow/python/distribute/numpy_dataset_test.py new file mode 100644 index 00000000000..04eae1daa2e --- /dev/null +++ b/tensorflow/python/distribute/numpy_dataset_test.py @@ -0,0 +1,44 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for numpy_dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.distribute import numpy_dataset +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope + + +class InitVarFromNumpyTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def test_creating_var_with_numpy_arrays(self): + with self.cached_session() as session: + x = np.asarray(np.random.random((64, 3)), dtype=np.float32) + initial = np.zeros_like(x) + var_x = variable_scope.variable(initial) + numpy_dataset.init_var_from_numpy(var_x, x, session) + val = self.evaluate(var_x.value()) + # Verify that the numpy value is copied to the variable. 
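A worked restatement of the slice-size arithmetic inside `init_var_from_numpy` above, with an assumed input shape to make the 64 MB bound concrete.

import numpy as np

# Assumed example: float32 input of shape (10000, 128, 128).
numpy_input = np.zeros((10000, 128, 128), dtype=np.float32)
dtype_size = 4                                # bytes per float32

# Bytes per batch element: 128 * 128 * 4 = 65,536.
byte_size_per_batch_element = int(np.prod(numpy_input.shape[1:])) * dtype_size

# Rows copied per assign: ceil(64 MiB / 65,536 B) = 1024, so the 10,000-row
# array is fed to the assign op in ceil(10000 / 1024) = 10 feed_dict calls,
# keeping every slice under the ~128 MB bound mentioned in the comment above.
batch_size_per_slice = int(np.ceil((64 << 20) / byte_size_per_batch_element))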
+ self.assertAllEqual(x, val) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index 71fbffdc0d8..ac5ee6f589f 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -26,6 +26,7 @@ from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver @@ -137,6 +138,7 @@ class ParameterServerStrategyExtended( assert cluster_spec.as_dict() worker_device = "/job:%s/task:%d" % (task_type, task_id) + self._input_host_device = numpy_dataset.SingleDevice(worker_device) # Define compute devices which is a list of device strings and one for each # replica. When there are GPUs, replicate operations on these GPUs. @@ -195,6 +197,7 @@ class ParameterServerStrategyExtended( def _initialize_local(self, cluster_resolver): """Initialize internal devices for local training.""" worker_device = device_util.canonicalize("/device:CPU:0") + self._input_host_device = numpy_dataset.SingleDevice(worker_device) num_gpus = cluster_resolver.num_accelerators() # Define compute devices which is a list of device strings and one for each # replica. When there are GPUs, replicate operations on these GPUs. @@ -262,6 +265,10 @@ class ParameterServerStrategyExtended( return input_lib.InputFunctionIterator(input_fn, self._input_workers, [input_context]) + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, self._input_host_device, session) + def _broadcast_to(self, tensor, destinations): # This is both a fast path for Python constants, and a way to delay # converting Python values to a tensor until we know what type it @@ -329,8 +336,12 @@ class ParameterServerStrategyExtended( var_creator = next_creator if "colocate_with" in kwargs: + colocate_with = kwargs["colocate_with"] + if isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return var_creator(*args, **kwargs) with ops.device(None): - with ops.colocate_with(kwargs["colocate_with"]): + with ops.colocate_with(colocate_with): return var_creator(*args, **kwargs) with ops.colocate_with(None, ignore_existing=True): diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py index 1678d843079..b6ef33f700e 100644 --- a/tensorflow/python/keras/engine/distributed_training_utils.py +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -34,25 +34,13 @@ from tensorflow.python.keras import callbacks from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import optimizers from tensorflow.python.keras.optimizer_v2 import optimizer_v2 -from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import distribution_strategy_context from 
tensorflow.python.training.mode_keys import ModeKeys from tensorflow.python.util import nest -def validate_not_in_strategy_scope(): - """Validate fit/eval/predict are not running in DS scope.""" - if distribution_strategy_context.has_distribution_strategy(): - if distribution_strategy_context.in_cross_replica_context(): - raise RuntimeError( - 'Fit/Eval/Predict should not be run inside the tf.distribute.Strategy' - ' scope. Only model creation and compilation should be in ' - 'tf.distribute.Strategy scope.') - - def set_weights(distribution_strategy, dist_model, weights): """Sets the weights of the replicated models. @@ -530,100 +518,11 @@ def get_batch_dimension(iterator): return dims[0] if dims else None -def get_cpu_device(distribution_strategy): - """Returns the CPU device of the TPU host or the default CPU device string. - - Args: - distribution_strategy: The DistributionStrategy used to compile the model. - - Returns: - A device string which is the TPU host's CPU device in case of - TPUDistributionStrategy or the default CPU device string in all other - cases. - - Raises: - NotImplementedError: We currently don't support copying numpy data to - multiple hosts in the case of Cloud TPU pods. - """ - if is_tpu_strategy(distribution_strategy): - if distribution_strategy.extended.num_hosts > 1: - raise NotImplementedError('TPUDistributionStrategy does not ' - 'support numpy inputs when running on Cloud' - 'TPU pods.') - return distribution_strategy.extended.get_host_cpu_device(0) - else: - # For all strategies except TPUDistributionStrategy - # TODO(anjalisridhar): We may need to modify this when we add support for - # multi-worker strategy. - return '/CPU:0' - - -def get_var_for_numpy(distribution_strategy, x): - if isinstance(x, list): - var_x = tuple([_get_var_for_numpy(distribution_strategy, single_input) - for single_input in x]) - else: - var_x = _get_var_for_numpy(distribution_strategy, x) - return var_x - - -def _get_var_for_numpy(distribution_strategy, input_array): - """Creates a variable and assigns the value of the numpy array to it. - - Args: - distribution_strategy: The DistributionStrategy used to compile the model. - input_array: The input numpy array whose value will be assigned to the - variable we create. - - Returns: - The variable to which we will copy the value of the input numpy array. - - """ - with ops.device(get_cpu_device(distribution_strategy)): - # Create and initialize a variable on the CPU device. This is the CPU - # device of the host in the case of TPUDistributionStrategy. - input_var = variables.VariableV1(array_ops.zeros(input_array.shape, - input_array.dtype), - trainable=False, use_resource=True) - K.get_session().run(input_var.initializer) - - # Create a placeholder for the numpy array input slices. We copy the value - # of the input numpy array to the variable in slices of size 64 MB to avoid - # running into memory issues or RPC message limits. - start_placeholder = array_ops.placeholder(dtypes.int64, ()) - end_placeholder = array_ops.placeholder(dtypes.int64, ()) - slice_placeholder = array_ops.placeholder(input_var.dtype) - assign_slice_op = input_var[start_placeholder:end_placeholder].assign( - slice_placeholder) - - # If each batch element is > 64 MB, then we copy each batch element - # individually. Otherwise, the slices will be < 128 MB. There might be padding - # which might mean that the slices are 128 MB even if the size of the - # tensor allocated is less than 128 MB. 
- # This formula gives slices with size: - # ceil(64 MB / byte size per batch element) bytes. - # Using ceil() guarantees we get a number >= 1. - - # Calculate the size of each batch element. - byte_size_per_batch_element = np.prod(input_array.shape[1:]) * \ - input_var.dtype.size - - # Calculate number of elements we want to copy per slice. - batch_size_per_slice = int(np.ceil((64 << 20) / byte_size_per_batch_element)) - - # Copy slices of the above size starting at 0, except the last slice will be - # smaller. - start = 0 - limit = input_array.shape[0] - while start < limit: - end = min(start + batch_size_per_slice, limit) - K.get_session().run(assign_slice_op, feed_dict={ - start_placeholder: start, - end_placeholder: end, - slice_placeholder: input_array[start:end]}) - start = end - - return input_var +def list_to_tuple(maybe_list): + """Datasets treat lists specially, so switch them to tuples.""" + if isinstance(maybe_list, list): + return tuple(maybe_list) + return maybe_list def _get_input_from_iterator(iterator, model): diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 1eda8cf797b..a65c2b6413f 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -2161,53 +2161,47 @@ class Model(Network): 'you should specify the `{steps_name}` argument.' .format(steps_name=steps_name)) - first_x_value = nest.flatten(x)[0] - if isinstance(first_x_value, np.ndarray): - # We need to use the drop_remainder argument to allow for a static - # input shape which is required for TPUs. - drop_remainder = self._distribution_strategy.require_static_shapes - if y is not None: - var_x = distributed_training_utils.get_var_for_numpy( - self._distribution_strategy, x) - var_y = distributed_training_utils.get_var_for_numpy( - self._distribution_strategy, y) - if sample_weight is not None: - var_sample_weights = distributed_training_utils.get_var_for_numpy( - self._distribution_strategy, sample_weight) + if ops.executing_eagerly_outside_functions(): + session = None + else: + session = K.get_session() - x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y, - var_sample_weights)) + with self._distribution_strategy.scope(): + first_x_value = nest.flatten(x)[0] + if isinstance(first_x_value, np.ndarray): + x = distributed_training_utils.list_to_tuple(x) + if y is not None: + y = distributed_training_utils.list_to_tuple(y) + if sample_weight is not None: + sample_weight = distributed_training_utils.list_to_tuple( + sample_weight) + in_tuple = (x, y, sample_weight) + else: + in_tuple = (x, y) else: - x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y)) + in_tuple = x if shuffle: # 1024 is a good buffer size since it is much larger than the average # batch size provided by the user and provides sufficient randomness. # One thing to keep in mind is the memory usage based on the size of # each sample. - x = x.shuffle(1024) - x = x.repeat() - x = x.batch(batch_size, drop_remainder=drop_remainder) - y = None - sample_weight = None + shuffle_buffer = 1024 + else: + shuffle_buffer = None + iterator = self._distribution_strategy.experimental_make_numpy_iterator( + in_tuple, batch_size, num_epochs=None, shuffle=shuffle_buffer, + session=session) else: - # This case is for the predict call where the dataset only contains - # inputs and no targets, i.e. 
it does not return a tuple - var_x = distributed_training_utils.get_var_for_numpy( - self._distribution_strategy, x) - x = dataset_ops.Dataset.from_tensor_slices(var_x) - x = x.batch(batch_size, drop_remainder=drop_remainder) + assert isinstance(x, dataset_ops.DatasetV2) + training_utils.validate_dataset_input(x, y, sample_weight, + validation_split) + iterator = self._distribution_strategy.make_dataset_iterator(x) - assert isinstance(x, dataset_ops.DatasetV2) - - with self._distribution_strategy.scope(): - iterator = self._distribution_strategy.make_dataset_iterator(x) init_op = iterator.initialize() if not context.executing_eagerly(): K.get_session().run(init_op) - training_utils.validate_dataset_input(x, y, sample_weight, - validation_split) return iterator def _standardize_user_data(self, diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 92a46e399a9..0bd79f2b473 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -59,8 +59,6 @@ def fit_distributed(model, first_x_value = nest.flatten(x)[0] if isinstance(first_x_value, np.ndarray): - # TODO(b/122314600): Remove the scope validate. - distributed_training_utils.validate_not_in_strategy_scope() steps_per_epoch, batch_size = ( distributed_training_utils.get_input_params( model._distribution_strategy, first_x_value, steps_per_epoch, @@ -141,8 +139,6 @@ def evaluate_distributed(model, distributed_training_utils.validate_inputs(x, y, model._distribution_strategy) first_x_value = nest.flatten(x)[0] if isinstance(first_x_value, np.ndarray): - # TODO(b/122314600): Remove the scope validate. - distributed_training_utils.validate_not_in_strategy_scope() steps, batch_size = distributed_training_utils.get_input_params( model._distribution_strategy, first_x_value, steps, batch_size) batch_size = model._validate_or_infer_batch_size(batch_size, steps, x) @@ -179,8 +175,6 @@ def predict_distributed(model, x, None, model._distribution_strategy) first_x_value = nest.flatten(x)[0] if isinstance(first_x_value, np.ndarray): - # TODO(b/122314600): Remove the scope validate. - distributed_training_utils.validate_not_in_strategy_scope() steps, batch_size = distributed_training_utils.get_input_params( model._distribution_strategy, first_x_value, steps, batch_size) batch_size = model._validate_or_infer_batch_size(batch_size, steps, x) diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index eaa563e84aa..c6cc0b60440 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -554,14 +554,15 @@ class Optimizer( # by most optimizers. It relies on the subclass implementing the following # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse(). - # Handle DistributionStrategy case. - if distribute_ctx.get_cross_replica_context(): - raise RuntimeError("Use `_distributed_apply()` instead of " - "`apply_gradients()` in a cross-replica context.") - # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by + # TODO(isaprykin): Get rid of `has_strategy()` check by # always calling _distributed_apply(), using the default distribution # as needed. - if distribute_ctx.has_distribution_strategy(): + if distribute_ctx.has_strategy(): + # Handle DistributionStrategy case. 
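A hedged sketch of the rewritten Keras numpy path shown above: inputs are normalized with `list_to_tuple`, packed into one tuple, and routed through `experimental_make_numpy_iterator` inside the strategy scope, replacing the deleted `get_var_for_numpy` machinery. The helper name and simplifications below are assumptions; the sample-weight-less and dataset branches are omitted.

from tensorflow.python.keras.engine.distributed_training_utils import list_to_tuple

def make_distributed_numpy_iterator_sketch(strategy, x, y, sample_weight,
                                           batch_size, shuffle, session):
  # tf.data stacks lists, so lists become tuples before building the dataset.
  in_tuple = (list_to_tuple(x), list_to_tuple(y), list_to_tuple(sample_weight))
  with strategy.scope():
    # num_epochs=None repeats indefinitely; Keras bounds epochs via steps.
    return strategy.experimental_make_numpy_iterator(
        in_tuple, batch_size, num_epochs=None,
        shuffle=1024 if shuffle else None, session=session)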
+ if distribute_ctx.in_cross_replica_context(): + raise RuntimeError("Use `_distributed_apply()` instead of " + "`apply_gradients()` in a cross-replica context.") + grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)() return distribute_ctx.get_replica_context().merge_call( self._distributed_apply, args=(grads_and_vars, global_step, name)) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt index b06c73d1260..9c29067b6d8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -75,6 +75,10 @@ tf_class { name: "experimental_initialize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_make_numpy_iterator" + argspec: "args=[\'self\', \'numpy_input\', \'batch_size\', \'num_epochs\', \'shuffle\', \'session\'], varargs=None, keywords=None, defaults=[\'1\', \'1024\', \'None\'], " + } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt index 77706e57133..37b620891fb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt @@ -50,6 +50,10 @@ tf_class { name: "colocate_vars_with" argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_make_numpy_dataset" + argspec: "args=[\'self\', \'numpy_input\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_run_steps_on_iterator" argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt index 9a1df551426..4aa6f1c4e14 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt @@ -74,6 +74,10 @@ tf_class { name: "experimental_initialize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_make_numpy_iterator" + argspec: "args=[\'self\', \'numpy_input\', \'batch_size\', \'num_epochs\', \'shuffle\', \'session\'], varargs=None, keywords=None, defaults=[\'1\', \'1024\', \'None\'], " + } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt index b06c73d1260..9c29067b6d8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -75,6 +75,10 @@ tf_class { name: "experimental_initialize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + 
member_method { + name: "experimental_make_numpy_iterator" + argspec: "args=[\'self\', \'numpy_input\', \'batch_size\', \'num_epochs\', \'shuffle\', \'session\'], varargs=None, keywords=None, defaults=[\'1\', \'1024\', \'None\'], " + } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt index 77706e57133..37b620891fb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt @@ -50,6 +50,10 @@ tf_class { name: "colocate_vars_with" argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_make_numpy_dataset" + argspec: "args=[\'self\', \'numpy_input\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "experimental_run_steps_on_iterator" argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt index 9a1df551426..4aa6f1c4e14 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt @@ -74,6 +74,10 @@ tf_class { name: "experimental_initialize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_make_numpy_iterator" + argspec: "args=[\'self\', \'numpy_input\', \'batch_size\', \'num_epochs\', \'shuffle\', \'session\'], varargs=None, keywords=None, defaults=[\'1\', \'1024\', \'None\'], " + } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "