Add DistStrat ReplicaContext.all_reduce.
				
					
				
			PiperOrigin-RevId: 228302155
This commit is contained in:
		
							parent
							
								
									07c6aa1a40
								
							
						
					
					
						commit
						81492c074a
					
				| @ -397,9 +397,11 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( | ||||
|         self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) | ||||
| 
 | ||||
| 
 | ||||
| class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, | ||||
|                                        strategy_test_lib.DistributionTestBase, | ||||
|                                        parameterized.TestCase): | ||||
| class LocalCollectiveAllReduceStrategy( | ||||
|     CollectiveAllReduceStrategyTestBase, | ||||
|     strategy_test_lib.DistributionTestBase, | ||||
|     strategy_test_lib.TwoDeviceDistributionTestBase, | ||||
|     parameterized.TestCase): | ||||
| 
 | ||||
|   def testMinimizeLossGraph(self, num_gpus=2): | ||||
|     # Collective ops doesn't support strategy with one device. | ||||
| @ -428,6 +430,42 @@ class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, | ||||
|     self._test_input_fn_iterator(None, None, num_gpus, | ||||
|                                  input_fn, expected_values) | ||||
| 
 | ||||
|   def testAllReduceSum(self): | ||||
|     if context.num_gpus() < 2: self.skipTest('Not enough GPUs') | ||||
|     distribution, target, config = self._get_test_object(None, None, num_gpus=2) | ||||
|     with self.cached_session(config=config, target=target): | ||||
|       self._test_all_reduce_sum(distribution) | ||||
| 
 | ||||
|   def testAllReduceSumGradients(self): | ||||
|     if context.num_gpus() < 2: self.skipTest('Not enough GPUs') | ||||
|     distribution, target, config = self._get_test_object(None, None, num_gpus=2) | ||||
|     with self.cached_session(config=config, target=target): | ||||
|       self._test_all_reduce_sum_gradients(distribution) | ||||
| 
 | ||||
|   def testAllReduceSumGradientTape(self): | ||||
|     if context.num_gpus() < 2: self.skipTest('Not enough GPUs') | ||||
|     distribution, target, config = self._get_test_object(None, None, num_gpus=2) | ||||
|     with self.cached_session(config=config, target=target): | ||||
|       self._test_all_reduce_sum_gradient_tape(distribution) | ||||
| 
 | ||||
|   def testAllReduceMean(self): | ||||
|     if context.num_gpus() < 2: self.skipTest('Not enough GPUs') | ||||
|     distribution, target, config = self._get_test_object(None, None, num_gpus=2) | ||||
|     with self.cached_session(config=config, target=target): | ||||
|       self._test_all_reduce_mean(distribution) | ||||
| 
 | ||||
|   def testAllReduceMeanGradients(self): | ||||
|     if context.num_gpus() < 2: self.skipTest('Not enough GPUs') | ||||
|     distribution, target, config = self._get_test_object(None, None, num_gpus=2) | ||||
|     with self.cached_session(config=config, target=target): | ||||
|       self._test_all_reduce_mean_gradients(distribution) | ||||
| 
 | ||||
|   def testAllReduceMeanGradientTape(self): | ||||
|     if context.num_gpus() < 2: self.skipTest('Not enough GPUs') | ||||
|     distribution, target, config = self._get_test_object(None, None, num_gpus=2) | ||||
|     with self.cached_session(config=config, target=target): | ||||
|       self._test_all_reduce_mean_gradient_tape(distribution) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   test.main() | ||||
|  | ||||
| @ -66,8 +66,10 @@ GPU_TEST = "test_gpu" in sys.argv[0] | ||||
|         combinations.core_mirrored_strategy_with_gpu_and_cpu, | ||||
|         combinations.core_mirrored_strategy_with_two_gpus], | ||||
|     mode=["graph", "eager"])) | ||||
| class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, | ||||
|                                         parameterized.TestCase): | ||||
| class MirroredTwoDeviceDistributionTest( | ||||
|     strategy_test_lib.DistributionTestBase, | ||||
|     strategy_test_lib.TwoDeviceDistributionTestBase, | ||||
|     parameterized.TestCase): | ||||
| 
 | ||||
|   def testMinimizeLoss(self, distribution): | ||||
|     if context.executing_eagerly(): | ||||
| @ -117,6 +119,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, | ||||
|   def testGlobalStepUpdate(self, distribution): | ||||
|     self._test_global_step_update(distribution) | ||||
| 
 | ||||
|   def testAllReduceSum(self, distribution): | ||||
|     self._test_all_reduce_sum(distribution) | ||||
| 
 | ||||
|   def testAllReduceSumGradients(self, distribution): | ||||
|     self._test_all_reduce_sum_gradients(distribution) | ||||
| 
 | ||||
|   def testAllReduceSumGradientTape(self, distribution): | ||||
|     self._test_all_reduce_sum_gradient_tape(distribution) | ||||
| 
 | ||||
|   def testAllReduceMean(self, distribution): | ||||
|     self._test_all_reduce_mean(distribution) | ||||
| 
 | ||||
|   def testAllReduceMeanGradients(self, distribution): | ||||
|     self._test_all_reduce_mean_gradients(distribution) | ||||
| 
 | ||||
|   def testAllReduceMeanGradientTape(self, distribution): | ||||
|     self._test_all_reduce_mean_gradient_tape(distribution) | ||||
| 
 | ||||
| 
 | ||||
| def one_device_combinations(): | ||||
|   return combinations.combine( | ||||
| @ -128,25 +148,42 @@ def one_device_combinations(): | ||||
|       mode=["graph", "eager"]) | ||||
| 
 | ||||
| 
 | ||||
| @combinations.generate(one_device_combinations()) | ||||
| class MirroredOneDeviceDistributionTest( | ||||
|     strategy_test_lib.DistributionTestBase, | ||||
|     strategy_test_lib.OneDeviceDistributionTestBase, | ||||
|     parameterized.TestCase): | ||||
| 
 | ||||
|   @combinations.generate(one_device_combinations()) | ||||
|   def testMinimizeLoss(self, distribution): | ||||
|     if context.executing_eagerly(): | ||||
|       self._test_minimize_loss_eager(distribution) | ||||
|     else: | ||||
|       self._test_minimize_loss_graph(distribution) | ||||
| 
 | ||||
|   @combinations.generate(one_device_combinations()) | ||||
|   def testReplicaId(self, distribution): | ||||
|     self._test_replica_id(distribution) | ||||
| 
 | ||||
|   @combinations.generate(one_device_combinations()) | ||||
|   def testCallAndMergeExceptions(self, distribution): | ||||
|     self._test_call_and_merge_exceptions(distribution) | ||||
| 
 | ||||
|   def testAllReduceSum(self, distribution): | ||||
|     self._test_all_reduce_sum(distribution) | ||||
| 
 | ||||
|   def testAllReduceSumGradients(self, distribution): | ||||
|     self._test_all_reduce_sum_gradients(distribution) | ||||
| 
 | ||||
|   def testAllReduceSumGradientTape(self, distribution): | ||||
|     self._test_all_reduce_sum_gradient_tape(distribution) | ||||
| 
 | ||||
|   def testAllReduceMean(self, distribution): | ||||
|     self._test_all_reduce_mean(distribution) | ||||
| 
 | ||||
|   def testAllReduceMeanGradients(self, distribution): | ||||
|     self._test_all_reduce_mean_gradients(distribution) | ||||
| 
 | ||||
|   def testAllReduceMeanGradientTape(self, distribution): | ||||
|     self._test_all_reduce_mean_gradient_tape(distribution) | ||||
| 
 | ||||
| 
 | ||||
| class MirroredStrategyVariableCreatorStackTest( | ||||
|     test.TestCase, parameterized.TestCase): | ||||
|  | ||||
| @ -25,7 +25,9 @@ from tensorflow.python.eager import test | ||||
| from tensorflow.python.framework import test_util | ||||
| 
 | ||||
| 
 | ||||
| class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): | ||||
| class OneDeviceStrategyTest( | ||||
|     strategy_test_lib.DistributionTestBase, | ||||
|     strategy_test_lib.OneDeviceDistributionTestBase): | ||||
| 
 | ||||
|   def _get_distribution_strategy(self): | ||||
|     return one_device_strategy.OneDeviceStrategy("/device:CPU:0") | ||||
| @ -57,6 +59,24 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): | ||||
|     self._test_input_fn_iterator( | ||||
|         iterator, d.extended.worker_devices, expected_values) | ||||
| 
 | ||||
|   def testAllReduceSum(self): | ||||
|     self._test_all_reduce_sum(self._get_distribution_strategy()) | ||||
| 
 | ||||
|   def testAllReduceSumGradients(self): | ||||
|     self._test_all_reduce_sum_gradients(self._get_distribution_strategy()) | ||||
| 
 | ||||
|   def testAllReduceSumGradientTape(self): | ||||
|     self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy()) | ||||
| 
 | ||||
|   def testAllReduceMean(self): | ||||
|     self._test_all_reduce_mean(self._get_distribution_strategy()) | ||||
| 
 | ||||
|   def testAllReduceMeanGradients(self): | ||||
|     self._test_all_reduce_mean_gradients(self._get_distribution_strategy()) | ||||
| 
 | ||||
|   def testAllReduceMeanGradientTape(self): | ||||
|     self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy()) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   test.main() | ||||
|  | ||||
| @ -603,9 +603,11 @@ class ParameterServerStrategyTestBase( | ||||
|         self.assertEqual(expected_value, computed_value) | ||||
| 
 | ||||
| 
 | ||||
| class ParameterServerStrategyTest(ParameterServerStrategyTestBase, | ||||
|                                   strategy_test_lib.DistributionTestBase, | ||||
|                                   parameterized.TestCase): | ||||
| class ParameterServerStrategyTest( | ||||
|     ParameterServerStrategyTestBase, | ||||
|     strategy_test_lib.DistributionTestBase, | ||||
|     strategy_test_lib.TwoDeviceDistributionTestBase, | ||||
|     parameterized.TestCase): | ||||
| 
 | ||||
|   @classmethod | ||||
|   def setUpClass(cls): | ||||
| @ -782,6 +784,36 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, | ||||
|     # Verify isolate_session_state | ||||
|     self.assertTrue(new_config.isolate_session_state) | ||||
| 
 | ||||
|   def testAllReduceSum(self): | ||||
|     distribution = parameter_server_strategy.ParameterServerStrategy( | ||||
|         num_gpus_per_worker=2) | ||||
|     self._test_all_reduce_sum(distribution) | ||||
| 
 | ||||
|   def testAllReduceSumGradients(self): | ||||
|     distribution = parameter_server_strategy.ParameterServerStrategy( | ||||
|         num_gpus_per_worker=2) | ||||
|     self._test_all_reduce_sum_gradients(distribution) | ||||
| 
 | ||||
|   def testAllReduceSumGradientTape(self): | ||||
|     distribution = parameter_server_strategy.ParameterServerStrategy( | ||||
|         num_gpus_per_worker=2) | ||||
|     self._test_all_reduce_sum_gradient_tape(distribution) | ||||
| 
 | ||||
|   def testAllReduceMean(self): | ||||
|     distribution = parameter_server_strategy.ParameterServerStrategy( | ||||
|         num_gpus_per_worker=2) | ||||
|     self._test_all_reduce_mean(distribution) | ||||
| 
 | ||||
|   def testAllReduceMeanGradients(self): | ||||
|     distribution = parameter_server_strategy.ParameterServerStrategy( | ||||
|         num_gpus_per_worker=2) | ||||
|     self._test_all_reduce_mean_gradients(distribution) | ||||
| 
 | ||||
|   def testAllReduceMeanGradientTape(self): | ||||
|     distribution = parameter_server_strategy.ParameterServerStrategy( | ||||
|         num_gpus_per_worker=2) | ||||
|     self._test_all_reduce_mean_gradient_tape(distribution) | ||||
| 
 | ||||
| 
 | ||||
| class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, | ||||
|                                            parameterized.TestCase): | ||||
|  | ||||
| @ -19,6 +19,7 @@ from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.core.protobuf import config_pb2 | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.distribute import distribution_strategy_context as ds_context | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.distribute import values | ||||
| @ -31,6 +32,7 @@ from tensorflow.python.framework import errors | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.layers import core | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops import gradients_impl | ||||
| from tensorflow.python.ops import init_ops | ||||
| from tensorflow.python.ops import variable_scope | ||||
| from tensorflow.python.ops import variables | ||||
| @ -292,3 +294,163 @@ class DistributionTestBase(test.TestCase): | ||||
|       global_step_tensors = strategy.unwrap(value) | ||||
|       global_step_values = self.evaluate(global_step_tensors) | ||||
|       self.assertEqual((1,) * len(global_step_tensors), global_step_values) | ||||
| 
 | ||||
| 
 | ||||
| class OneDeviceDistributionTestBase(test.TestCase): | ||||
|   """Some tests that should work with any one-device DistributionStrategy.""" | ||||
| 
 | ||||
|   def _test_all_reduce_sum(self, strategy): | ||||
|     self._test_collective_comms( | ||||
|         strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.])) | ||||
| 
 | ||||
|   def _test_all_reduce_sum_gradients(self, strategy): | ||||
|     self._test_collective_comms_gradients( | ||||
|         strategy, _all_sum, inputs=[4.], expected_grads=[4.]) | ||||
| 
 | ||||
|   def _test_all_reduce_sum_gradient_tape(self, strategy): | ||||
|     self._test_collective_comms_gradient_tape( | ||||
|         strategy, _all_sum, inputs=[4.], expected_grads=[4.]) | ||||
| 
 | ||||
|   def _test_all_reduce_mean(self, strategy): | ||||
|     self._test_collective_comms( | ||||
|         strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.])) | ||||
| 
 | ||||
|   def _test_all_reduce_mean_gradients(self, strategy): | ||||
|     self._test_collective_comms_gradients( | ||||
|         strategy, _all_mean, inputs=[5.], expected_grads=[5.]) | ||||
| 
 | ||||
|   def _test_all_reduce_mean_gradient_tape(self, strategy): | ||||
|     self._test_collective_comms_gradient_tape( | ||||
|         strategy, _all_mean, inputs=[5.], expected_grads=[5.]) | ||||
| 
 | ||||
|   def _test_collective_comms(self, strategy, comm_fn, inputs, expected): | ||||
|     inputs = strategy.make_input_fn_iterator( | ||||
|         lambda _: dataset_ops.Dataset.from_tensors(inputs)) | ||||
| 
 | ||||
|     self.evaluate(inputs.initialize()) | ||||
|     outputs = self.evaluate( | ||||
|         list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) | ||||
|     self.assertAllEqual([expected[0]], outputs[0]) | ||||
|     self.assertAllEqual([expected[1]], outputs[1]) | ||||
| 
 | ||||
|   def _test_collective_comms_gradients( | ||||
|       self, strategy, comm_fn, inputs, expected_grads): | ||||
|     if context.executing_eagerly(): | ||||
|       self.skipTest("`tf.gradients` is not supported with eager execution.") | ||||
| 
 | ||||
|     def step(c): | ||||
|       x = constant_op.constant(42.) | ||||
|       y = comm_fn(x) * c | ||||
|       return gradients_impl.gradients(y, [x])[0] | ||||
| 
 | ||||
|     inputs = strategy.make_input_fn_iterator( | ||||
|         lambda _: dataset_ops.Dataset.from_tensors(inputs)) | ||||
| 
 | ||||
|     self.evaluate(inputs.initialize()) | ||||
|     self.assertAllEqual( | ||||
|         expected_grads, | ||||
|         self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) | ||||
| 
 | ||||
|   def _test_collective_comms_gradient_tape( | ||||
|       self, strategy, comm_fn, inputs, expected_grads): | ||||
|     def step(c): | ||||
|       x = constant_op.constant(42.) | ||||
|       with backprop.GradientTape() as tape: | ||||
|         tape.watch(x) | ||||
|         y = comm_fn(x) * c | ||||
|       return tape.gradient(y, x) | ||||
| 
 | ||||
|     inputs = strategy.make_input_fn_iterator( | ||||
|         lambda _: dataset_ops.Dataset.from_tensors(inputs)) | ||||
| 
 | ||||
|     self.evaluate(inputs.initialize()) | ||||
|     self.assertAllEqual( | ||||
|         expected_grads, | ||||
|         self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) | ||||
| 
 | ||||
| 
 | ||||
| class TwoDeviceDistributionTestBase(test.TestCase): | ||||
|   """Some tests that should work with any two-device DistributionStrategy.""" | ||||
| 
 | ||||
|   def _test_all_reduce_sum(self, strategy): | ||||
|     self._test_collective_comms( | ||||
|         strategy, _all_sum, | ||||
|         inputs=([1., 3.], [[39., 2.], [3., 41.]]), | ||||
|         expected=(4., [42., 43.])) | ||||
| 
 | ||||
|   def _test_all_reduce_sum_gradients(self, strategy): | ||||
|     self._test_collective_comms_gradients( | ||||
|         strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) | ||||
| 
 | ||||
|   def _test_all_reduce_sum_gradient_tape(self, strategy): | ||||
|     self._test_collective_comms_gradient_tape( | ||||
|         strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) | ||||
| 
 | ||||
|   def _test_all_reduce_mean(self, strategy): | ||||
|     self._test_collective_comms( | ||||
|         strategy, _all_mean, | ||||
|         inputs=([1., 3.], [[39., 2.], [3., 41.]]), | ||||
|         expected=(2., [21., 21.5])) | ||||
| 
 | ||||
|   def _test_all_reduce_mean_gradients(self, strategy): | ||||
|     self._test_collective_comms_gradients( | ||||
|         strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) | ||||
| 
 | ||||
|   def _test_all_reduce_mean_gradient_tape(self, strategy): | ||||
|     self._test_collective_comms_gradient_tape( | ||||
|         strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) | ||||
| 
 | ||||
|   def _test_collective_comms(self, strategy, comm_fn, inputs, expected): | ||||
|     inputs = strategy.make_input_fn_iterator( | ||||
|         lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) | ||||
| 
 | ||||
|     self.evaluate(inputs.initialize()) | ||||
|     outputs = self.evaluate( | ||||
|         list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) | ||||
|     self.assertAllEqual([expected[0], expected[0]], outputs[0]) | ||||
|     self.assertAllEqual([expected[1], expected[1]], outputs[1]) | ||||
| 
 | ||||
|   def _test_collective_comms_gradients( | ||||
|       self, strategy, comm_fn, inputs, expected_grads): | ||||
|     if context.executing_eagerly(): | ||||
|       self.skipTest("`tf.gradients` is not supported with eager execution.") | ||||
| 
 | ||||
|     def step(c): | ||||
|       x = constant_op.constant(42.) | ||||
|       y = comm_fn(x) * c | ||||
|       return gradients_impl.gradients(y, [x])[0] | ||||
| 
 | ||||
|     inputs = strategy.make_input_fn_iterator( | ||||
|         lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) | ||||
| 
 | ||||
|     self.evaluate(inputs.initialize()) | ||||
|     self.assertAllEqual( | ||||
|         expected_grads, | ||||
|         self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) | ||||
| 
 | ||||
|   def _test_collective_comms_gradient_tape( | ||||
|       self, strategy, comm_fn, inputs, expected_grads): | ||||
|     def step(c): | ||||
|       x = constant_op.constant(42.) | ||||
|       with backprop.GradientTape() as tape: | ||||
|         tape.watch(x) | ||||
|         y = comm_fn(x) * c | ||||
|       return tape.gradient(y, x) | ||||
| 
 | ||||
|     inputs = strategy.make_input_fn_iterator( | ||||
|         lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) | ||||
| 
 | ||||
|     self.evaluate(inputs.initialize()) | ||||
|     self.assertAllEqual( | ||||
|         expected_grads, | ||||
|         self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) | ||||
| 
 | ||||
| 
 | ||||
| def _all_sum(value): | ||||
|   ctx = ds_context.get_replica_context() | ||||
|   return ctx.all_reduce(reduce_util.ReduceOp.SUM, value) | ||||
| 
 | ||||
| 
 | ||||
| def _all_mean(value): | ||||
|   ctx = ds_context.get_replica_context() | ||||
|   return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value) | ||||
|  | ||||
| @ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops import control_flow_ops | ||||
| from tensorflow.python.ops import custom_gradient | ||||
| from tensorflow.python.ops import resource_variable_ops | ||||
| from tensorflow.python.ops import variable_scope | ||||
| from tensorflow.python.ops.losses import losses_impl | ||||
| @ -1554,6 +1555,50 @@ class ReplicaContext(object): | ||||
|     require_replica_context(self) | ||||
|     return (device_util.current(),) | ||||
| 
 | ||||
|   def all_reduce(self, reduce_op, value): | ||||
|     """All-reduces the given `Tensor` nest across replicas. | ||||
| 
 | ||||
|     If `all_reduce` is called in any replica, it must be called in all replicas. | ||||
|     The nested structure and `Tensor` shapes must be identical in all replicas. | ||||
| 
 | ||||
|     IMPORTANT: The ordering of communications must be identical in all replicas. | ||||
| 
 | ||||
|     Example with two replicas: | ||||
|       Replica 0 `value`: {'a': 1, 'b': [40,  1]} | ||||
|       Replica 1 `value`: {'a': 3, 'b': [ 2, 98]} | ||||
| 
 | ||||
|       If `reduce_op` == `SUM`: | ||||
|         Result (on all replicas): {'a': 4, 'b': [42, 99]} | ||||
| 
 | ||||
|       If `reduce_op` == `MEAN`: | ||||
|         Result (on all replicas): {'a': 2, 'b': [21, 49.5]} | ||||
| 
 | ||||
|     Args: | ||||
|       reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. | ||||
|       value: The nested structure of `Tensor`s to all-reduced. | ||||
|         The structure must be compatible with `tf.nest`. | ||||
| 
 | ||||
|     Returns: | ||||
|        A `Tensor` nest with the reduced `value`s from each replica. | ||||
|     """ | ||||
|     def batch_all_reduce(strategy, *value_flat): | ||||
|       return strategy.extended.batch_reduce_to( | ||||
|           reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat]) | ||||
| 
 | ||||
|     if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: | ||||
|       # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. | ||||
|       @custom_gradient.custom_gradient | ||||
|       def grad_wrapper(*xs): | ||||
|         ys = self.merge_call(batch_all_reduce, args=xs) | ||||
|         # The gradient of an all-sum is itself an all-sum (all-mean, likewise). | ||||
|         return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) | ||||
|       return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) | ||||
|     else: | ||||
|       # TODO(cjfj): Implement gradients for other reductions. | ||||
|       reduced = nest.pack_sequence_as( | ||||
|           value, self.merge_call(batch_all_reduce, args=nest.flatten(value))) | ||||
|       return nest.map_structure(array_ops.prevent_gradient, reduced) | ||||
| 
 | ||||
|   # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient | ||||
|   # all-reduce. It would return a function returning the result of reducing `t` | ||||
|   # across all replicas. The caller would wait to call this function until they | ||||
| @ -1564,6 +1609,15 @@ class ReplicaContext(object): | ||||
|   #   to that point that the first result is needed. Most likely this can be | ||||
|   #   implemented in terms of `merge_call()` and `batch_reduce_to()`. | ||||
| 
 | ||||
| 
 | ||||
| def _batch_reduce_destination(x): | ||||
|   """Returns the destinations for batch all-reduce.""" | ||||
|   if isinstance(x, ops.Tensor):  # One device strategies. | ||||
|     return x.device | ||||
|   else: | ||||
|     return x | ||||
| 
 | ||||
| 
 | ||||
| # ------------------------------------------------------------------------------ | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -26,6 +26,10 @@ tf_class { | ||||
|     name: "__init__" | ||||
|     argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None" | ||||
|   } | ||||
|   member_method { | ||||
|     name: "all_reduce" | ||||
|     argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None" | ||||
|   } | ||||
|   member_method { | ||||
|     name: "merge_call" | ||||
|     argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], " | ||||
|  | ||||
| @ -26,6 +26,10 @@ tf_class { | ||||
|     name: "__init__" | ||||
|     argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None" | ||||
|   } | ||||
|   member_method { | ||||
|     name: "all_reduce" | ||||
|     argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None" | ||||
|   } | ||||
|   member_method { | ||||
|     name: "merge_call" | ||||
|     argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], " | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user