Add DistStrat ReplicaContext.all_reduce.

PiperOrigin-RevId: 228302155
Author: Chris Jones (committed by TensorFlower Gardener)
Date: 2019-01-08 02:50:54 -08:00
parent 07c6aa1a40
commit 81492c074a
8 changed files with 363 additions and 12 deletions

@@ -397,9 +397,11 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
        self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)


class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase,
                                       strategy_test_lib.DistributionTestBase,
                                       parameterized.TestCase):
class LocalCollectiveAllReduceStrategy(
    CollectiveAllReduceStrategyTestBase,
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase,
    parameterized.TestCase):

  def testMinimizeLossGraph(self, num_gpus=2):
    # Collective ops don't support a strategy with only one device.
@@ -428,6 +430,42 @@ class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase,
      self._test_input_fn_iterator(None, None, num_gpus,
                                   input_fn, expected_values)

  def testAllReduceSum(self):
    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
    distribution, target, config = self._get_test_object(
        None, None, num_gpus=2)
    with self.cached_session(config=config, target=target):
      self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self):
    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
    distribution, target, config = self._get_test_object(
        None, None, num_gpus=2)
    with self.cached_session(config=config, target=target):
      self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self):
    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
    distribution, target, config = self._get_test_object(
        None, None, num_gpus=2)
    with self.cached_session(config=config, target=target):
      self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self):
    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
    distribution, target, config = self._get_test_object(
        None, None, num_gpus=2)
    with self.cached_session(config=config, target=target):
      self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self):
    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
    distribution, target, config = self._get_test_object(
        None, None, num_gpus=2)
    with self.cached_session(config=config, target=target):
      self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self):
    if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
    distribution, target, config = self._get_test_object(
        None, None, num_gpus=2)
    with self.cached_session(config=config, target=target):
      self._test_all_reduce_mean_gradient_tape(distribution)


if __name__ == '__main__':
  test.main()

@@ -66,8 +66,10 @@ GPU_TEST = "test_gpu" in sys.argv[0]
        combinations.core_mirrored_strategy_with_gpu_and_cpu,
        combinations.core_mirrored_strategy_with_two_gpus],
    mode=["graph", "eager"]))
class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase,
                                        parameterized.TestCase):
class MirroredTwoDeviceDistributionTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase,
    parameterized.TestCase):

  def testMinimizeLoss(self, distribution):
    if context.executing_eagerly():

@@ -117,6 +119,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase,
  def testGlobalStepUpdate(self, distribution):
    self._test_global_step_update(distribution)

  def testAllReduceSum(self, distribution):
    self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self, distribution):
    self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self, distribution):
    self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self, distribution):
    self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self, distribution):
    self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self, distribution):
    self._test_all_reduce_mean_gradient_tape(distribution)

def one_device_combinations():
  return combinations.combine(

@@ -128,25 +148,42 @@ def one_device_combinations():
      mode=["graph", "eager"])


@combinations.generate(one_device_combinations())
class MirroredOneDeviceDistributionTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.OneDeviceDistributionTestBase,
    parameterized.TestCase):

  @combinations.generate(one_device_combinations())
  def testMinimizeLoss(self, distribution):
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution)

  @combinations.generate(one_device_combinations())
  def testReplicaId(self, distribution):
    self._test_replica_id(distribution)

  @combinations.generate(one_device_combinations())
  def testCallAndMergeExceptions(self, distribution):
    self._test_call_and_merge_exceptions(distribution)

  def testAllReduceSum(self, distribution):
    self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self, distribution):
    self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self, distribution):
    self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self, distribution):
    self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self, distribution):
    self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self, distribution):
    self._test_all_reduce_mean_gradient_tape(distribution)


class MirroredStrategyVariableCreatorStackTest(
    test.TestCase, parameterized.TestCase):

@@ -25,7 +25,9 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import test_util


class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
class OneDeviceStrategyTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.OneDeviceDistributionTestBase):

  def _get_distribution_strategy(self):
    return one_device_strategy.OneDeviceStrategy("/device:CPU:0")

@@ -57,6 +59,24 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
    self._test_input_fn_iterator(
        iterator, d.extended.worker_devices, expected_values)

  def testAllReduceSum(self):
    self._test_all_reduce_sum(self._get_distribution_strategy())

  def testAllReduceSumGradients(self):
    self._test_all_reduce_sum_gradients(self._get_distribution_strategy())

  def testAllReduceSumGradientTape(self):
    self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy())

  def testAllReduceMean(self):
    self._test_all_reduce_mean(self._get_distribution_strategy())

  def testAllReduceMeanGradients(self):
    self._test_all_reduce_mean_gradients(self._get_distribution_strategy())

  def testAllReduceMeanGradientTape(self):
    self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy())


if __name__ == "__main__":
  test.main()

@@ -603,9 +603,11 @@ class ParameterServerStrategyTestBase(
    self.assertEqual(expected_value, computed_value)


class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
                                  strategy_test_lib.DistributionTestBase,
                                  parameterized.TestCase):
class ParameterServerStrategyTest(
    ParameterServerStrategyTestBase,
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase,
    parameterized.TestCase):

  @classmethod
  def setUpClass(cls):

@@ -782,6 +784,36 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
    # Verify isolate_session_state
    self.assertTrue(new_config.isolate_session_state)

  def testAllReduceSum(self):
    distribution = parameter_server_strategy.ParameterServerStrategy(
        num_gpus_per_worker=2)
    self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self):
    distribution = parameter_server_strategy.ParameterServerStrategy(
        num_gpus_per_worker=2)
    self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self):
    distribution = parameter_server_strategy.ParameterServerStrategy(
        num_gpus_per_worker=2)
    self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self):
    distribution = parameter_server_strategy.ParameterServerStrategy(
        num_gpus_per_worker=2)
    self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self):
    distribution = parameter_server_strategy.ParameterServerStrategy(
        num_gpus_per_worker=2)
    self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self):
    distribution = parameter_server_strategy.ParameterServerStrategy(
        num_gpus_per_worker=2)
    self._test_all_reduce_mean_gradient_tape(distribution)


class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
                                           parameterized.TestCase):

@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values

@@ -31,6 +32,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables

@@ -292,3 +294,163 @@ class DistributionTestBase(test.TestCase):
    global_step_tensors = strategy.unwrap(value)
    global_step_values = self.evaluate(global_step_tensors)
    self.assertEqual((1,) * len(global_step_tensors), global_step_values)


class OneDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any one-device DistributionStrategy."""

  def _test_all_reduce_sum(self, strategy):
    self._test_collective_comms(
        strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.]))

  def _test_all_reduce_sum_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_sum_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_mean(self, strategy):
    self._test_collective_comms(
        strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.]))

  def _test_all_reduce_mean_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_all_reduce_mean_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    outputs = self.evaluate(
        list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
    self.assertAllEqual([expected[0]], outputs[0])
    self.assertAllEqual([expected[1]], outputs[1])

  def _test_collective_comms_gradients(
      self, strategy, comm_fn, inputs, expected_grads):
    if context.executing_eagerly():
      self.skipTest("`tf.gradients` is not supported with eager execution.")

    def step(c):
      x = constant_op.constant(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))

  def _test_collective_comms_gradient_tape(
      self, strategy, comm_fn, inputs, expected_grads):
    def step(c):
      x = constant_op.constant(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))


class TwoDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any two-device DistributionStrategy."""

  def _test_all_reduce_sum(self, strategy):
    self._test_collective_comms(
        strategy, _all_sum,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(4., [42., 43.]))

  def _test_all_reduce_sum_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.])

  def _test_all_reduce_sum_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.])

  def _test_all_reduce_mean(self, strategy):
    self._test_collective_comms(
        strategy, _all_mean,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(2., [21., 21.5]))

  def _test_all_reduce_mean_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.])

  def _test_all_reduce_mean_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.])

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    outputs = self.evaluate(
        list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
    self.assertAllEqual([expected[0], expected[0]], outputs[0])
    self.assertAllEqual([expected[1], expected[1]], outputs[1])

  def _test_collective_comms_gradients(
      self, strategy, comm_fn, inputs, expected_grads):
    if context.executing_eagerly():
      self.skipTest("`tf.gradients` is not supported with eager execution.")

    def step(c):
      x = constant_op.constant(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))

  def _test_collective_comms_gradient_tape(
      self, strategy, comm_fn, inputs, expected_grads):
    def step(c):
      x = constant_op.constant(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))


def _all_sum(value):
  ctx = ds_context.get_replica_context()
  return ctx.all_reduce(reduce_util.ReduceOp.SUM, value)


def _all_mean(value):
  ctx = ds_context.get_replica_context()
  return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)
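
A note on the expected values used in `TwoDeviceDistributionTestBase` above: `Dataset.from_tensor_slices` splits each input along its first dimension, so with `inputs=([1., 3.], [[39., 2.], [3., 41.]])` replica 0 receives `(1., [39., 2.])` and replica 1 receives `(3., [3., 41.])`; the all-reduce then produces the same reduced result on both replicas. (In the one-device base class the expected outputs simply equal the inputs, since a single replica reduces only with itself.) The following sketch is not part of the change; it is plain NumPy reproducing the arithmetic behind `expected=(4., [42., 43.])` for SUM and `(2., [21., 21.5])` for MEAN.

import numpy as np

# Per-replica inputs, as Dataset.from_tensor_slices distributes them across
# the two replicas (a sketch of the test's arithmetic, not TensorFlow code).
replica_values = [
    (1., np.array([39., 2.])),   # replica 0
    (3., np.array([3., 41.])),   # replica 1
]

all_sum = (sum(v[0] for v in replica_values),
           sum(v[1] for v in replica_values))
all_mean = (all_sum[0] / len(replica_values),
            all_sum[1] / len(replica_values))

print(all_sum)   # (4.0, array([42., 43.]))  -> expected=(4., [42., 43.])
print(all_mean)  # (2.0, array([21., 21.5])) -> expected=(2., [21., 21.5])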

@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl

@@ -1554,6 +1555,50 @@ class ReplicaContext(object):
    require_replica_context(self)
    return (device_util.current(),)

  def all_reduce(self, reduce_op, value):
    """All-reduces the given `Tensor` nest across replicas.

    If `all_reduce` is called in any replica, it must be called in all
    replicas. The nested structure and `Tensor` shapes must be identical in
    all replicas.

    IMPORTANT: The ordering of communications must be identical in all
    replicas.

    Example with two replicas:
      Replica 0 `value`: {'a': 1, 'b': [40,  1]}
      Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}

      If `reduce_op` == `SUM`:
        Result (on all replicas): {'a': 4, 'b': [42, 99]}

      If `reduce_op` == `MEAN`:
        Result (on all replicas): {'a': 2, 'b': [21, 49.5]}

    Args:
      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
      value: The nested structure of `Tensor`s to be all-reduced. The structure
        must be compatible with `tf.nest`.

    Returns:
      A `Tensor` nest with the reduced `value`s from each replica.
    """
    def batch_all_reduce(strategy, *value_flat):
      return strategy.extended.batch_reduce_to(
          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat])

    if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
      # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
      @custom_gradient.custom_gradient
      def grad_wrapper(*xs):
        ys = self.merge_call(batch_all_reduce, args=xs)
        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)

      return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
    else:
      # TODO(cjfj): Implement gradients for other reductions.
      reduced = nest.pack_sequence_as(
          value, self.merge_call(batch_all_reduce, args=nest.flatten(value)))
      return nest.map_structure(array_ops.prevent_gradient, reduced)
  # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
  # all-reduce. It would return a function returning the result of reducing `t`
  # across all replicas. The caller would wait to call this function until they

@@ -1564,6 +1609,15 @@ class ReplicaContext(object):
  # to that point that the first result is needed. Most likely this can be
  # implemented in terms of `merge_call()` and `batch_reduce_to()`.


def _batch_reduce_destination(x):
  """Returns the destinations for batch all-reduce."""
  if isinstance(x, ops.Tensor):  # One device strategies.
    return x.device
  else:
    return x
# ------------------------------------------------------------------------------
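
To make the new API concrete, here is a minimal usage sketch (not part of the change) of a replica function that calls `ReplicaContext.all_reduce`, written against the same internal modules the test helpers above use. The names `replica_fn`, `some_strategy`, and `iterator` are hypothetical placeholders. Note also why the SUM/MEAN branch registers a custom gradient: every replica's reduced output depends on every replica's input, so the backward pass is the same all-reduce applied to the incoming gradients, which is exactly what `grad_wrapper` returns.

from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util


def replica_fn(local_value):
  """Runs on each replica; `local_value` is that replica's slice of the input."""
  ctx = ds_context.get_replica_context()
  # Every replica must issue the same all_reduce calls, in the same order.
  total = ctx.all_reduce(reduce_util.ReduceOp.SUM, local_value)
  mean = ctx.all_reduce(reduce_util.ReduceOp.MEAN, local_value)
  return total, mean

# Hypothetical driver, mirroring the test helpers:
#   results = some_strategy.experimental_run(replica_fn, iterator)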

@@ -26,6 +26,10 @@ tf_class {
    name: "__init__"
    argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "all_reduce"
    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "merge_call"
    argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "

@@ -26,6 +26,10 @@ tf_class {
    name: "__init__"
    argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "all_reduce"
    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "merge_call"
    argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "