Add DistStrat ReplicaContext.all_reduce.

PiperOrigin-RevId: 228302155
parent 07c6aa1a40
commit 81492c074a
@@ -397,9 +397,11 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
         self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
 
 
-class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase,
-                                       strategy_test_lib.DistributionTestBase,
-                                       parameterized.TestCase):
+class LocalCollectiveAllReduceStrategy(
+    CollectiveAllReduceStrategyTestBase,
+    strategy_test_lib.DistributionTestBase,
+    strategy_test_lib.TwoDeviceDistributionTestBase,
+    parameterized.TestCase):
 
   def testMinimizeLossGraph(self, num_gpus=2):
     # Collective ops doesn't support strategy with one device.
@@ -428,6 +430,42 @@ class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase,
     self._test_input_fn_iterator(None, None, num_gpus,
                                  input_fn, expected_values)
 
+  def testAllReduceSum(self):
+    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
+    distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+    with self.cached_session(config=config, target=target):
+      self._test_all_reduce_sum(distribution)
+
+  def testAllReduceSumGradients(self):
+    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
+    distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+    with self.cached_session(config=config, target=target):
+      self._test_all_reduce_sum_gradients(distribution)
+
+  def testAllReduceSumGradientTape(self):
+    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
+    distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+    with self.cached_session(config=config, target=target):
+      self._test_all_reduce_sum_gradient_tape(distribution)
+
+  def testAllReduceMean(self):
+    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
+    distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+    with self.cached_session(config=config, target=target):
+      self._test_all_reduce_mean(distribution)
+
+  def testAllReduceMeanGradients(self):
+    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
+    distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+    with self.cached_session(config=config, target=target):
+      self._test_all_reduce_mean_gradients(distribution)
+
+  def testAllReduceMeanGradientTape(self):
+    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
+    distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+    with self.cached_session(config=config, target=target):
+      self._test_all_reduce_mean_gradient_tape(distribution)
+
 
 if __name__ == '__main__':
   test.main()
@@ -66,8 +66,10 @@ GPU_TEST = "test_gpu" in sys.argv[0]
         combinations.core_mirrored_strategy_with_gpu_and_cpu,
         combinations.core_mirrored_strategy_with_two_gpus],
     mode=["graph", "eager"]))
-class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase,
-                                        parameterized.TestCase):
+class MirroredTwoDeviceDistributionTest(
+    strategy_test_lib.DistributionTestBase,
+    strategy_test_lib.TwoDeviceDistributionTestBase,
+    parameterized.TestCase):
 
   def testMinimizeLoss(self, distribution):
     if context.executing_eagerly():
@@ -117,6 +119,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase,
   def testGlobalStepUpdate(self, distribution):
     self._test_global_step_update(distribution)
 
+  def testAllReduceSum(self, distribution):
+    self._test_all_reduce_sum(distribution)
+
+  def testAllReduceSumGradients(self, distribution):
+    self._test_all_reduce_sum_gradients(distribution)
+
+  def testAllReduceSumGradientTape(self, distribution):
+    self._test_all_reduce_sum_gradient_tape(distribution)
+
+  def testAllReduceMean(self, distribution):
+    self._test_all_reduce_mean(distribution)
+
+  def testAllReduceMeanGradients(self, distribution):
+    self._test_all_reduce_mean_gradients(distribution)
+
+  def testAllReduceMeanGradientTape(self, distribution):
+    self._test_all_reduce_mean_gradient_tape(distribution)
+
 
 def one_device_combinations():
   return combinations.combine(
@@ -128,25 +148,42 @@ def one_device_combinations():
       mode=["graph", "eager"])
 
 
+@combinations.generate(one_device_combinations())
 class MirroredOneDeviceDistributionTest(
     strategy_test_lib.DistributionTestBase,
+    strategy_test_lib.OneDeviceDistributionTestBase,
     parameterized.TestCase):
 
-  @combinations.generate(one_device_combinations())
   def testMinimizeLoss(self, distribution):
     if context.executing_eagerly():
       self._test_minimize_loss_eager(distribution)
     else:
       self._test_minimize_loss_graph(distribution)
 
-  @combinations.generate(one_device_combinations())
   def testReplicaId(self, distribution):
     self._test_replica_id(distribution)
 
-  @combinations.generate(one_device_combinations())
   def testCallAndMergeExceptions(self, distribution):
     self._test_call_and_merge_exceptions(distribution)
 
+  def testAllReduceSum(self, distribution):
+    self._test_all_reduce_sum(distribution)
+
+  def testAllReduceSumGradients(self, distribution):
+    self._test_all_reduce_sum_gradients(distribution)
+
+  def testAllReduceSumGradientTape(self, distribution):
+    self._test_all_reduce_sum_gradient_tape(distribution)
+
+  def testAllReduceMean(self, distribution):
+    self._test_all_reduce_mean(distribution)
+
+  def testAllReduceMeanGradients(self, distribution):
+    self._test_all_reduce_mean_gradients(distribution)
+
+  def testAllReduceMeanGradientTape(self, distribution):
+    self._test_all_reduce_mean_gradient_tape(distribution)
+
 
 class MirroredStrategyVariableCreatorStackTest(
     test.TestCase, parameterized.TestCase):
@@ -25,7 +25,9 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import test_util
 
 
-class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
+class OneDeviceStrategyTest(
+    strategy_test_lib.DistributionTestBase,
+    strategy_test_lib.OneDeviceDistributionTestBase):
 
   def _get_distribution_strategy(self):
     return one_device_strategy.OneDeviceStrategy("/device:CPU:0")
@@ -57,6 +59,24 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
     self._test_input_fn_iterator(
         iterator, d.extended.worker_devices, expected_values)
 
+  def testAllReduceSum(self):
+    self._test_all_reduce_sum(self._get_distribution_strategy())
+
+  def testAllReduceSumGradients(self):
+    self._test_all_reduce_sum_gradients(self._get_distribution_strategy())
+
+  def testAllReduceSumGradientTape(self):
+    self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy())
+
+  def testAllReduceMean(self):
+    self._test_all_reduce_mean(self._get_distribution_strategy())
+
+  def testAllReduceMeanGradients(self):
+    self._test_all_reduce_mean_gradients(self._get_distribution_strategy())
+
+  def testAllReduceMeanGradientTape(self):
+    self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy())
+
 
 if __name__ == "__main__":
   test.main()
@@ -603,9 +603,11 @@ class ParameterServerStrategyTestBase(
     self.assertEqual(expected_value, computed_value)
 
 
-class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
-                                  strategy_test_lib.DistributionTestBase,
-                                  parameterized.TestCase):
+class ParameterServerStrategyTest(
+    ParameterServerStrategyTestBase,
+    strategy_test_lib.DistributionTestBase,
+    strategy_test_lib.TwoDeviceDistributionTestBase,
+    parameterized.TestCase):
 
   @classmethod
   def setUpClass(cls):
@@ -782,6 +784,36 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
     # Verify isolate_session_state
     self.assertTrue(new_config.isolate_session_state)
 
+  def testAllReduceSum(self):
+    distribution = parameter_server_strategy.ParameterServerStrategy(
+        num_gpus_per_worker=2)
+    self._test_all_reduce_sum(distribution)
+
+  def testAllReduceSumGradients(self):
+    distribution = parameter_server_strategy.ParameterServerStrategy(
+        num_gpus_per_worker=2)
+    self._test_all_reduce_sum_gradients(distribution)
+
+  def testAllReduceSumGradientTape(self):
+    distribution = parameter_server_strategy.ParameterServerStrategy(
+        num_gpus_per_worker=2)
+    self._test_all_reduce_sum_gradient_tape(distribution)
+
+  def testAllReduceMean(self):
+    distribution = parameter_server_strategy.ParameterServerStrategy(
+        num_gpus_per_worker=2)
+    self._test_all_reduce_mean(distribution)
+
+  def testAllReduceMeanGradients(self):
+    distribution = parameter_server_strategy.ParameterServerStrategy(
+        num_gpus_per_worker=2)
+    self._test_all_reduce_mean_gradients(distribution)
+
+  def testAllReduceMeanGradientTape(self):
+    distribution = parameter_server_strategy.ParameterServerStrategy(
+        num_gpus_per_worker=2)
+    self._test_all_reduce_mean_gradient_tape(distribution)
+
 
 class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
                                            parameterized.TestCase):
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import values
@@ -31,6 +32,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.layers import core
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
@@ -292,3 +294,163 @@ class DistributionTestBase(test.TestCase):
     global_step_tensors = strategy.unwrap(value)
     global_step_values = self.evaluate(global_step_tensors)
     self.assertEqual((1,) * len(global_step_tensors), global_step_values)
+
+
+class OneDeviceDistributionTestBase(test.TestCase):
+  """Some tests that should work with any one-device DistributionStrategy."""
+
+  def _test_all_reduce_sum(self, strategy):
+    self._test_collective_comms(
+        strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.]))
+
+  def _test_all_reduce_sum_gradients(self, strategy):
+    self._test_collective_comms_gradients(
+        strategy, _all_sum, inputs=[4.], expected_grads=[4.])
+
+  def _test_all_reduce_sum_gradient_tape(self, strategy):
+    self._test_collective_comms_gradient_tape(
+        strategy, _all_sum, inputs=[4.], expected_grads=[4.])
+
+  def _test_all_reduce_mean(self, strategy):
+    self._test_collective_comms(
+        strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.]))
+
+  def _test_all_reduce_mean_gradients(self, strategy):
+    self._test_collective_comms_gradients(
+        strategy, _all_mean, inputs=[5.], expected_grads=[5.])
+
+  def _test_all_reduce_mean_gradient_tape(self, strategy):
+    self._test_collective_comms_gradient_tape(
+        strategy, _all_mean, inputs=[5.], expected_grads=[5.])
+
+  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
+    inputs = strategy.make_input_fn_iterator(
+        lambda _: dataset_ops.Dataset.from_tensors(inputs))
+
+    self.evaluate(inputs.initialize())
+    outputs = self.evaluate(
+        list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
+    self.assertAllEqual([expected[0]], outputs[0])
+    self.assertAllEqual([expected[1]], outputs[1])
+
+  def _test_collective_comms_gradients(
+      self, strategy, comm_fn, inputs, expected_grads):
+    if context.executing_eagerly():
+      self.skipTest("`tf.gradients` is not supported with eager execution.")
+
+    def step(c):
+      x = constant_op.constant(42.)
+      y = comm_fn(x) * c
+      return gradients_impl.gradients(y, [x])[0]
+
+    inputs = strategy.make_input_fn_iterator(
+        lambda _: dataset_ops.Dataset.from_tensors(inputs))
+
+    self.evaluate(inputs.initialize())
+    self.assertAllEqual(
+        expected_grads,
+        self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
+
+  def _test_collective_comms_gradient_tape(
+      self, strategy, comm_fn, inputs, expected_grads):
+    def step(c):
+      x = constant_op.constant(42.)
+      with backprop.GradientTape() as tape:
+        tape.watch(x)
+        y = comm_fn(x) * c
+      return tape.gradient(y, x)
+
+    inputs = strategy.make_input_fn_iterator(
+        lambda _: dataset_ops.Dataset.from_tensors(inputs))
+
+    self.evaluate(inputs.initialize())
+    self.assertAllEqual(
+        expected_grads,
+        self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
+
+
+class TwoDeviceDistributionTestBase(test.TestCase):
+  """Some tests that should work with any two-device DistributionStrategy."""
+
+  def _test_all_reduce_sum(self, strategy):
+    self._test_collective_comms(
+        strategy, _all_sum,
+        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
+        expected=(4., [42., 43.]))
+
+  def _test_all_reduce_sum_gradients(self, strategy):
+    self._test_collective_comms_gradients(
+        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.])
+
+  def _test_all_reduce_sum_gradient_tape(self, strategy):
+    self._test_collective_comms_gradient_tape(
+        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.])
+
+  def _test_all_reduce_mean(self, strategy):
+    self._test_collective_comms(
+        strategy, _all_mean,
+        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
+        expected=(2., [21., 21.5]))
+
+  def _test_all_reduce_mean_gradients(self, strategy):
+    self._test_collective_comms_gradients(
+        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.])
+
+  def _test_all_reduce_mean_gradient_tape(self, strategy):
+    self._test_collective_comms_gradient_tape(
+        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.])
+
+  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
+    inputs = strategy.make_input_fn_iterator(
+        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
+
+    self.evaluate(inputs.initialize())
+    outputs = self.evaluate(
+        list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
+    self.assertAllEqual([expected[0], expected[0]], outputs[0])
+    self.assertAllEqual([expected[1], expected[1]], outputs[1])
+
+  def _test_collective_comms_gradients(
+      self, strategy, comm_fn, inputs, expected_grads):
+    if context.executing_eagerly():
+      self.skipTest("`tf.gradients` is not supported with eager execution.")
+
+    def step(c):
+      x = constant_op.constant(42.)
+      y = comm_fn(x) * c
+      return gradients_impl.gradients(y, [x])[0]
+
+    inputs = strategy.make_input_fn_iterator(
+        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
+
+    self.evaluate(inputs.initialize())
+    self.assertAllEqual(
+        expected_grads,
+        self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
+
+  def _test_collective_comms_gradient_tape(
+      self, strategy, comm_fn, inputs, expected_grads):
+    def step(c):
+      x = constant_op.constant(42.)
+      with backprop.GradientTape() as tape:
+        tape.watch(x)
+        y = comm_fn(x) * c
+      return tape.gradient(y, x)
+
+    inputs = strategy.make_input_fn_iterator(
+        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
+
+    self.evaluate(inputs.initialize())
+    self.assertAllEqual(
+        expected_grads,
+        self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
+
+
+def _all_sum(value):
+  ctx = ds_context.get_replica_context()
+  return ctx.all_reduce(reduce_util.ReduceOp.SUM, value)
+
+
+def _all_mean(value):
+  ctx = ds_context.get_replica_context()
+  return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)
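Editor's note (not part of the diff): the gradient expectations above follow from the code comment that the gradient of an all-sum is itself an all-sum of the per-replica incoming gradients (and likewise for all-mean). With per-replica inputs `c = [1., 3.]` and `y = comm_fn(x) * c`, every replica therefore gets the all-reduced `c` as its gradient. A tiny plain-Python sanity check of those numbers:

# Sanity check of the expected_grads used in TwoDeviceDistributionTestBase.
c = [1., 3.]                            # per-replica values fed to step()
grad_sum = [sum(c)] * len(c)            # all-sum gradient  -> [4., 4.]
grad_mean = [sum(c) / len(c)] * len(c)  # all-mean gradient -> [2., 2.]
assert grad_sum == [4., 4.] and grad_mean == [2., 2.]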
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops.losses import losses_impl
@@ -1554,6 +1555,50 @@ class ReplicaContext(object):
     require_replica_context(self)
     return (device_util.current(),)
 
+  def all_reduce(self, reduce_op, value):
+    """All-reduces the given `Tensor` nest across replicas.
+
+    If `all_reduce` is called in any replica, it must be called in all replicas.
+    The nested structure and `Tensor` shapes must be identical in all replicas.
+
+    IMPORTANT: The ordering of communications must be identical in all replicas.
+
+    Example with two replicas:
+      Replica 0 `value`: {'a': 1, 'b': [40, 1]}
+      Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}
+
+      If `reduce_op` == `SUM`:
+        Result (on all replicas): {'a': 4, 'b': [42, 99]}
+
+      If `reduce_op` == `MEAN`:
+        Result (on all replicas): {'a': 2, 'b': [21, 49.5]}
+
+    Args:
+      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
+      value: The nested structure of `Tensor`s to be all-reduced.
+        The structure must be compatible with `tf.nest`.
+
+    Returns:
+      A `Tensor` nest with the reduced `value`s from each replica.
+    """
+    def batch_all_reduce(strategy, *value_flat):
+      return strategy.extended.batch_reduce_to(
+          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat])
+
+    if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
+      # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
+      @custom_gradient.custom_gradient
+      def grad_wrapper(*xs):
+        ys = self.merge_call(batch_all_reduce, args=xs)
+        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
+        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
+      return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
+    else:
+      # TODO(cjfj): Implement gradients for other reductions.
+      reduced = nest.pack_sequence_as(
+          value, self.merge_call(batch_all_reduce, args=nest.flatten(value)))
+      return nest.map_structure(array_ops.prevent_gradient, reduced)
+
   # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
   # all-reduce. It would return a function returning the result of reducing `t`
   # across all replicas. The caller would wait to call this function until they
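Editor's note (not part of the diff): a minimal usage sketch of the new API, mirroring the test helpers in strategy_test_lib above. It assumes a `strategy` with two replicas (for example, a two-device mirrored strategy like those used in the tests) and uses only names that appear in this change; in graph mode the iterator would also need to be initialized, as the tests do.

def replica_fn(x):
  # Runs once per replica; every replica must make the same all_reduce call.
  ctx = ds_context.get_replica_context()
  return ctx.all_reduce(reduce_util.ReduceOp.SUM, x)

# Feed replica 0 the value 1. and replica 1 the value 3.
iterator = strategy.make_input_fn_iterator(
    lambda _: dataset_ops.Dataset.from_tensor_slices([1., 3.]))
# After the all-reduce, each replica's result is 1. + 3. = 4.
per_replica_result = strategy.experimental_run(replica_fn, iterator)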
@@ -1564,6 +1609,15 @@ class ReplicaContext(object):
   # to that point that the first result is needed. Most likely this can be
   # implemented in terms of `merge_call()` and `batch_reduce_to()`.
 
+
+def _batch_reduce_destination(x):
+  """Returns the destinations for batch all-reduce."""
+  if isinstance(x, ops.Tensor):  # One device strategies.
+    return x.device
+  else:
+    return x
+
+
 # ------------------------------------------------------------------------------
 
 
@@ -26,6 +26,10 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "all_reduce"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "merge_call"
     argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
@@ -26,6 +26,10 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "all_reduce"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "merge_call"
     argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "