Implement gather/all_gather for DefaultStrategy and OneDeviceStrategy.

PiperOrigin-RevId: 335796948
Change-Id: I0fdbf629ae599f7d88967c23d3bb3c5dd1b0b0f6
This commit is contained in:
Xinyi Wang 2020-10-06 22:45:08 -07:00 committed by TensorFlower Gardener
parent ee4be8edb8
commit 6876c21157
3 changed files with 43 additions and 15 deletions

View File

@ -3193,7 +3193,7 @@ class ReplicaContext(object):
def grad_wrapper(*xs):
ys = self.merge_call(batch_all_gather, args=xs)
# The gradient of an all-gather is itself an all-gather.
return ys, lambda *dy_s: self.all_gather(dy_s, axis)
return ys, lambda *dy_s: self._all_gather(dy_s, axis)
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
@ -3346,6 +3346,10 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
del reduce_op, destinations, experimental_hints
return value
def _gather_to_implementation(self, value, destinations, axis, experimental_hints):
  """Gathers `value` for the default (single-replica) strategy.

  With only one replica there is nothing to concatenate, so the gather
  is the identity: the input value is returned unchanged.

  Args:
    value: The per-replica value to gather.
    destinations: Ignored — a single replica has no other destinations.
    axis: Ignored — no concatenation happens with one replica.
    experimental_hints: Ignored — no collective communication is performed.

  Returns:
    `value`, unchanged.
  """
  del destinations, axis, experimental_hints
  return value
def _update(self, var, fn, args, kwargs, group):
# The implementations of _update() and _update_non_slot() are identical
# except _update() passes `var` as the first argument to `fn()`.

View File

@ -383,6 +383,11 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
del reduce_op, destinations, experimental_hints
return value
def _gather_to_implementation(self, value, destinations, axis,
                              experimental_hints):
  """Gathers `value` for `OneDeviceStrategy`.

  With a single device/replica the gather is the identity: there is no
  cross-replica concatenation to perform, so the input is returned as-is.

  Args:
    value: The per-replica value to gather.
    destinations: Ignored — there is only one device.
    axis: Ignored — no concatenation happens with one replica.
    experimental_hints: Ignored — no collective communication is performed.

  Returns:
    `value`, unchanged.
  """
  del destinations, axis, experimental_hints
  return value
def _update(self, var, fn, args, kwargs, group):
# The implementations of _update() and _update_non_slot() are identical
# except _update() passes `var` as the first argument to `fn()`.

View File

@ -115,6 +115,9 @@ class ReduceTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.default_strategy,
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
@ -238,8 +241,13 @@ class GatherTest(test.TestCase, parameterized.TestCase):
def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
"""Different at non-`axis`-th dimension : [1, 1], [1, 2], 0th -> raise error."""
if _get_num_devices_per_worker(strategy) > 1:
if isinstance(strategy, CollectiveAllReduceStrategy
) and _get_num_replicas_per_client(strategy) > 1:
self.skipTest('b/167331966')
if strategy.num_replicas_in_sync <= 1:
self.skipTest('Test for more than 1 replica only.')
def value_fn(ctx):
return constant_op.constant(
1, shape=(1, ctx.replica_id_in_sync_group + 1))
@ -284,7 +292,8 @@ class GatherTest(test.TestCase, parameterized.TestCase):
"""Different rank: [1,], [1, 2] -> raise error."""
if strategy.num_replicas_in_sync <= 1:
self.skipTest('Test for more than 1 replicas.')
if _get_num_devices_per_worker(strategy) > 1:
if isinstance(strategy, CollectiveAllReduceStrategy
) and _get_num_replicas_per_client(strategy) > 1:
self.skipTest('b/167331966')
def value_fn(ctx):
return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))
@ -308,6 +317,9 @@ class GatherTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.default_strategy,
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
@ -334,7 +346,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
all_value = [value_on_replica for _ in range(strategy.num_replicas_in_sync)]
expect = array_ops.concat(all_value, axis=axis)
expected_result = [expect] * _get_num_devices_per_worker(strategy)
expected_result = [expect] * _get_num_replicas_per_client(strategy)
self.assertAllClose(result, expected_result)
@ -409,7 +421,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
if not pure_eager:
run = def_function.function(run)
expected_result = [expect] * _get_num_devices_per_worker(strategy)
expected_result = [expect] * _get_num_replicas_per_client(strategy)
result = strategy.experimental_local_results(
strategy.run(run, args=(per_replica_value,)))
self.assertAllEqual(result, expected_result)
@ -443,7 +455,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
if not pure_eager:
run = def_function.function(run)
expected_result = [expect] * _get_num_devices_per_worker(strategy)
expected_result = [expect] * _get_num_replicas_per_client(strategy)
result = strategy.experimental_local_results(
strategy.run(run, args=(per_replica_value,)))
self.assertAllEqual(result, expected_result)
@ -469,7 +481,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
# 1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))
raise ValueError('Add your own expect according to num_replicas_in sync')
expected_per_replica_1 = [expect_1] * _get_num_devices_per_worker(strategy)
expected_per_replica_1 = [expect_1] * _get_num_replicas_per_client(strategy)
value_2 = constant_op.constant([[[1, 2], [1, 2]]])
@ -485,7 +497,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
# [value_2 for _ in range(strategy.num_replicas_in_sync)], axis=axis)
raise ValueError('Add your own expect according to num_replicas_in sync')
expected_per_replica_2 = [expect_2] * _get_num_devices_per_worker(strategy)
expected_per_replica_2 = [expect_2] * _get_num_replicas_per_client(strategy)
def run(value):
value_1 = array_ops.identity(value)
@ -517,7 +529,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
all_value = [single_value for _ in range(strategy.num_replicas_in_sync)]
expect = array_ops.concat(all_value, axis=axis)
expected_per_replica = [expect] * _get_num_devices_per_worker(strategy)
expected_per_replica = [expect] * _get_num_replicas_per_client(strategy)
result = strategy.run(run)
for gathered_result in result:
@ -527,9 +539,13 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
"""Different at non-`axis`-th dimension : [2, 1], [1, 1], all_gather(...axis=1...) -> raise error."""
if _get_num_devices_per_worker(strategy) > 1:
if isinstance(strategy, CollectiveAllReduceStrategy
) and _get_num_replicas_per_client(strategy) > 1:
self.skipTest('b/167331966')
if strategy.num_replicas_in_sync <= 1:
self.skipTest('Test for more than 1 replica only.')
def value_fn(ctx):
return constant_op.constant(
1, shape=(1, ctx.replica_id_in_sync_group + 1))
@ -571,7 +587,8 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
"""Different rank: [1,], [1, 2] -> raise error."""
if strategy.num_replicas_in_sync <= 1:
self.skipTest('Test for more than 1 replicas.')
if _get_num_devices_per_worker(strategy) > 1:
if isinstance(strategy, CollectiveAllReduceStrategy
) and _get_num_replicas_per_client(strategy) > 1:
self.skipTest('b/167331966')
def value_fn(ctx):
return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))
@ -601,10 +618,12 @@ def _make_indexed_slices(values, indices, dense_shape):
return tensor
def _get_num_devices_per_worker(strategy):
  """Returns the number of accelerator devices per worker (at least 1).

  NOTE(review): the previous docstring claimed this returned "the number of
  workers in the current cluster", but the code queries
  `cluster_resolver.num_accelerators()` — i.e. devices on one worker — and
  clamps the result to a minimum of 1 (CPU-only case).
  """
  resolver = strategy.cluster_resolver
  return max(nest.flatten(resolver.num_accelerators())[0], 1)
def _get_num_replicas_per_client(strategy):
  """Returns the number of replicas hosted by the current client.

  For `CollectiveAllReduceStrategy` (multi-worker), each client hosts one
  replica per local accelerator, so this is the local accelerator count
  clamped to a minimum of 1 (CPU-only case). For every other strategy
  (default, one-device, mirrored), all replicas live in this client, so it
  is simply `num_replicas_in_sync`.
  """
  if isinstance(strategy, CollectiveAllReduceStrategy):
    resolver = strategy.cluster_resolver
    return max(nest.flatten(resolver.num_accelerators())[0], 1)
  else:
    return strategy.num_replicas_in_sync
@combinations.generate(