From 81492c074aa134d8756b361557ec2dd9aad3cdc1 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 8 Jan 2019 02:50:54 -0800 Subject: [PATCH] Add DistStrat `ReplicaContext.all_reduce`. PiperOrigin-RevId: 228302155 --- .../collective_all_reduce_strategy_test.py | 44 ++++- .../python/mirrored_strategy_multigpu_test.py | 47 ++++- .../python/one_device_strategy_test.py | 22 ++- .../python/parameter_server_strategy_test.py | 38 +++- .../distribute/python/strategy_test_lib.py | 162 ++++++++++++++++++ .../python/distribute/distribute_lib.py | 54 ++++++ ...nsorflow.distribute.-replica-context.pbtxt | 4 + ...nsorflow.distribute.-replica-context.pbtxt | 4 + 8 files changed, 363 insertions(+), 12 deletions(-) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 0fb672dded7..4c8c01a216a 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -397,9 +397,11 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) -class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class LocalCollectiveAllReduceStrategy( + CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): def testMinimizeLossGraph(self, num_gpus=2): # Collective ops doesn't support strategy with one device. @@ -428,6 +430,42 @@ class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, self._test_input_fn_iterator(None, None, num_gpus, input_fn, expected_values) + def testAllReduceSum(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + 
self._test_all_reduce_mean_gradient_tape(distribution) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 9cbc34412da..9821828b2d6 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -66,8 +66,10 @@ GPU_TEST = "test_gpu" in sys.argv[0] combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=["graph", "eager"])) -class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class MirroredTwoDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): def testMinimizeLoss(self, distribution): if context.executing_eagerly(): @@ -117,6 +119,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, def testGlobalStepUpdate(self, distribution): self._test_global_step_update(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + def one_device_combinations(): return combinations.combine( @@ -128,25 +148,42 @@ def one_device_combinations(): mode=["graph", "eager"]) +@combinations.generate(one_device_combinations()) class MirroredOneDeviceDistributionTest( strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase, parameterized.TestCase): - @combinations.generate(one_device_combinations()) def testMinimizeLoss(self, distribution): if context.executing_eagerly(): self._test_minimize_loss_eager(distribution) else: self._test_minimize_loss_graph(distribution) - @combinations.generate(one_device_combinations()) def testReplicaId(self, distribution): self._test_replica_id(distribution) - @combinations.generate(one_device_combinations()) def testCallAndMergeExceptions(self, distribution): self._test_call_and_merge_exceptions(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + class MirroredStrategyVariableCreatorStackTest( test.TestCase, parameterized.TestCase): diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index d46cd6f529e..2403dc8f125 100644 --- 
a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -25,7 +25,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util -class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): +class OneDeviceStrategyTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase): def _get_distribution_strategy(self): return one_device_strategy.OneDeviceStrategy("/device:CPU:0") @@ -57,6 +59,24 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): self._test_input_fn_iterator( iterator, d.extended.worker_devices, expected_values) + def testAllReduceSum(self): + self._test_all_reduce_sum(self._get_distribution_strategy()) + + def testAllReduceSumGradients(self): + self._test_all_reduce_sum_gradients(self._get_distribution_strategy()) + + def testAllReduceSumGradientTape(self): + self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy()) + + def testAllReduceMean(self): + self._test_all_reduce_mean(self._get_distribution_strategy()) + + def testAllReduceMeanGradients(self): + self._test_all_reduce_mean_gradients(self._get_distribution_strategy()) + + def testAllReduceMeanGradientTape(self): + self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index d8c6c50db23..7836687e7d6 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -603,9 +603,11 @@ class ParameterServerStrategyTestBase( self.assertEqual(expected_value, computed_value) -class ParameterServerStrategyTest(ParameterServerStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class ParameterServerStrategyTest( + ParameterServerStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): @@ -782,6 +784,36 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, # Verify isolate_session_state self.assertTrue(new_config.isolate_session_state) + def testAllReduceSum(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradient_tape(distribution) + class 
ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 6e5280e3563..4fbd630cf72 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values @@ -31,6 +32,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -292,3 +294,163 @@ class DistributionTestBase(test.TestCase): global_step_tensors = strategy.unwrap(value) global_step_values = self.evaluate(global_step_tensors) self.assertEqual((1,) * len(global_step_tensors), global_step_values) + + +class OneDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any one-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0]], outputs[0]) + self.assertAllEqual([expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) 
+ y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) + with backprop.GradientTape() as tape: + tape.watch(x) + y = comm_fn(x) * c + return tape.gradient(y, x) + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + +class TwoDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any two-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(2., [21., 21.5])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def _test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0], expected[0]], outputs[0]) + self.assertAllEqual([expected[1], expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) + y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) 
+      with backprop.GradientTape() as tape:
+        tape.watch(x)
+        y = comm_fn(x) * c
+      return tape.gradient(y, x)
+
+    inputs = strategy.make_input_fn_iterator(
+        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
+
+    self.evaluate(inputs.initialize())
+    self.assertAllEqual(
+        expected_grads,
+        self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
+
+
+def _all_sum(value):
+  ctx = ds_context.get_replica_context()
+  return ctx.all_reduce(reduce_util.ReduceOp.SUM, value)
+
+
+def _all_mean(value):
+  ctx = ds_context.get_replica_context()
+  return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 76cbdd53d9d..4ad8cc00b8a 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops.losses import losses_impl
@@ -1554,6 +1555,50 @@ class ReplicaContext(object):
     require_replica_context(self)
     return (device_util.current(),)
 
+  def all_reduce(self, reduce_op, value):
+    """All-reduces the given `Tensor` nest across replicas.
+
+    If `all_reduce` is called in any replica, it must be called in all replicas.
+    The nested structure and `Tensor` shapes must be identical in all replicas.
+
+    IMPORTANT: The ordering of communications must be identical in all replicas.
+
+    Example with two replicas:
+      Replica 0 `value`: {'a': 1, 'b': [40, 1]}
+      Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}
+
+      If `reduce_op` == `SUM`:
+        Result (on all replicas): {'a': 4, 'b': [42, 99]}
+
+      If `reduce_op` == `MEAN`:
+        Result (on all replicas): {'a': 2, 'b': [21, 49.5]}
+
+    Args:
+      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
+      value: The nested structure of `Tensor`s to all-reduce.
+        The structure must be compatible with `tf.nest`.
+
+    Returns:
+      A `Tensor` nest with the reduced `value`s from each replica.
+    """
+    def batch_all_reduce(strategy, *value_flat):
+      return strategy.extended.batch_reduce_to(
+          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat])
+
+    if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
+      # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
+      @custom_gradient.custom_gradient
+      def grad_wrapper(*xs):
+        ys = self.merge_call(batch_all_reduce, args=xs)
+        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
+        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
+      return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
+    else:
+      # TODO(cjfj): Implement gradients for other reductions.
+      reduced = nest.pack_sequence_as(
+          value, self.merge_call(batch_all_reduce, args=nest.flatten(value)))
+      return nest.map_structure(array_ops.prevent_gradient, reduced)
+
   # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
   # all-reduce. It would return a function returning the result of reducing `t`
   # across all replicas. The caller would wait to call this function until they
@@ -1564,6 +1609,15 @@
   # to that point that the first result is needed. 
Most likely this can be # implemented in terms of `merge_call()` and `batch_reduce_to()`. + +def _batch_reduce_destination(x): + """Returns the destinations for batch all-reduce.""" + if isinstance(x, ops.Tensor): # One device strategies. + return x.device + else: + return x + + # ------------------------------------------------------------------------------ diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt index df707e8920e..22f8160c964 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt @@ -26,6 +26,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "all_reduce" + argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "merge_call" argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt index df707e8920e..22f8160c964 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt @@ -26,6 +26,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "all_reduce" + argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "merge_call" argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
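
Usage sketch (reviewer illustration, not part of the patch): the new `ReplicaContext.all_reduce` is called from inside a replica function. The sketch below is written against the later public API names (`tf.distribute.MirroredStrategy`, `Strategy.run`, `Strategy.experimental_local_results`), which are assumptions here; at the time of this patch the tests above drive the method through `strategy.experimental_run` with an input iterator instead.

import tensorflow as tf

# Mirrors across all visible GPUs; on a CPU-only machine there is a single
# replica, where all_reduce reduces over one value (the identity).
strategy = tf.distribute.MirroredStrategy()

def replica_fn(x):
  ctx = tf.distribute.get_replica_context()
  # Per the docstring contract: same nest structure, same shapes, and the
  # same call ordering on every replica; every replica gets the reduced value.
  total = ctx.all_reduce(tf.distribute.ReduceOp.SUM, x)
  mean = ctx.all_reduce(tf.distribute.ReduceOp.MEAN, x)
  return total, mean

per_replica = strategy.run(replica_fn, args=(tf.constant(2.0),))
print(strategy.experimental_local_results(per_replica[0]))  # one sum per replica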
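The docstring's two-replica example can be checked by hand. This NumPy sketch mirrors the flatten, reduce, repack flow that `all_reduce` performs via `nest.flatten`, `batch_reduce_to`, and `nest.pack_sequence_as`; the flattening order [a, b[0], b[1]] is assumed for illustration.

import numpy as np

replica0 = np.array([1.0, 40.0, 1.0])   # {'a': 1, 'b': [40, 1]}, flattened
replica1 = np.array([3.0, 2.0, 98.0])   # {'a': 3, 'b': [ 2, 98]}, flattened

print(np.sum([replica0, replica1], axis=0))   # [ 4. 42. 99. ] -> {'a': 4, 'b': [42, 99]}
print(np.mean([replica0, replica1], axis=0))  # [ 2. 21. 49.5] -> {'a': 2, 'b': [21, 49.5]}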
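The expected gradients in `TwoDeviceDistributionTestBase` follow from the identity noted inside `grad_wrapper`: the gradient of an all-sum is itself an all-sum of the incoming cotangents (all-mean, likewise). In `_test_all_reduce_sum_gradients`, each replica computes `y_r = all_sum(x) * c_r` with per-replica inputs `c = [1., 3.]`, so the cotangent reaching the all-sum on replica `r` is `c_r`, and every replica's `dx` is `sum_r c_r = 4`, hence `expected_grads=[4., 4.]` (the mean variant gives `[2., 2.]`). A hand computation, for illustration only:

import numpy as np

c = np.array([1.0, 3.0])   # per-replica multipliers, i.e. the cotangents of all_sum(x)

dx_sum = np.full_like(c, c.sum())    # all-sum of cotangents  -> [4. 4.]
dx_mean = np.full_like(c, c.mean())  # all-mean of cotangents -> [2. 2.]
print(dx_sum, dx_mean)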