Implement gather/all_gather for DefaultStrategy and OneDeviceStrategy.
PiperOrigin-RevId: 335796948 Change-Id: I0fdbf629ae599f7d88967c23d3bb3c5dd1b0b0f6
This commit is contained in:
parent
ee4be8edb8
commit
6876c21157
@ -3193,7 +3193,7 @@ class ReplicaContext(object):
|
||||
def grad_wrapper(*xs):
|
||||
ys = self.merge_call(batch_all_gather, args=xs)
|
||||
# The gradient of an all-gather is itself an all-gather.
|
||||
return ys, lambda *dy_s: self.all_gather(dy_s, axis)
|
||||
return ys, lambda *dy_s: self._all_gather(dy_s, axis)
|
||||
|
||||
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
|
||||
|
||||
@ -3346,6 +3346,10 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
|
||||
del reduce_op, destinations, experimental_hints
|
||||
return value
|
||||
|
||||
def _gather_to_implementation(self, value, destinations, axis, experimental_hints):
|
||||
del destinations, axis, experimental_hints
|
||||
return value
|
||||
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
# The implementations of _update() and _update_non_slot() are identical
|
||||
# except _update() passes `var` as the first argument to `fn()`.
|
||||
|
||||
@ -383,6 +383,11 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
|
||||
del reduce_op, destinations, experimental_hints
|
||||
return value
|
||||
|
||||
def _gather_to_implementation(self, value, destinations, axis,
|
||||
experimental_hints):
|
||||
del destinations, axis, experimental_hints
|
||||
return value
|
||||
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
# The implementations of _update() and _update_non_slot() are identical
|
||||
# except _update() passes `var` as the first argument to `fn()`.
|
||||
|
||||
@ -115,6 +115,9 @@ class ReduceTest(test.TestCase, parameterized.TestCase):
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
strategy=[
|
||||
strategy_combinations.default_strategy,
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.one_device_strategy_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
@ -238,8 +241,13 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
|
||||
"""Different at non-`axis`-th dimension : [1, 1], [1, 2], 0th -> raise error."""
|
||||
if _get_num_devices_per_worker(strategy) > 1:
|
||||
if isinstance(strategy, CollectiveAllReduceStrategy
|
||||
) and _get_num_replicas_per_client(strategy) > 1:
|
||||
self.skipTest('b/167331966')
|
||||
|
||||
if strategy.num_replicas_in_sync <= 1:
|
||||
self.skipTest('Test for more than 1 replica only.')
|
||||
|
||||
def value_fn(ctx):
|
||||
return constant_op.constant(
|
||||
1, shape=(1, ctx.replica_id_in_sync_group + 1))
|
||||
@ -284,7 +292,8 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
"""Different rank: [1,], [1, 2] -> raise error."""
|
||||
if strategy.num_replicas_in_sync <= 1:
|
||||
self.skipTest('Test for more than 1 replicas.')
|
||||
if _get_num_devices_per_worker(strategy) > 1:
|
||||
if isinstance(strategy, CollectiveAllReduceStrategy
|
||||
) and _get_num_replicas_per_client(strategy) > 1:
|
||||
self.skipTest('b/167331966')
|
||||
def value_fn(ctx):
|
||||
return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))
|
||||
@ -308,6 +317,9 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
strategy=[
|
||||
strategy_combinations.default_strategy,
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.one_device_strategy_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
@ -334,7 +346,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
all_value = [value_on_replica for _ in range(strategy.num_replicas_in_sync)]
|
||||
expect = array_ops.concat(all_value, axis=axis)
|
||||
expected_result = [expect] * _get_num_devices_per_worker(strategy)
|
||||
expected_result = [expect] * _get_num_replicas_per_client(strategy)
|
||||
|
||||
self.assertAllClose(result, expected_result)
|
||||
|
||||
@ -409,7 +421,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
|
||||
expected_result = [expect] * _get_num_devices_per_worker(strategy)
|
||||
expected_result = [expect] * _get_num_replicas_per_client(strategy)
|
||||
result = strategy.experimental_local_results(
|
||||
strategy.run(run, args=(per_replica_value,)))
|
||||
self.assertAllEqual(result, expected_result)
|
||||
@ -443,7 +455,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
|
||||
expected_result = [expect] * _get_num_devices_per_worker(strategy)
|
||||
expected_result = [expect] * _get_num_replicas_per_client(strategy)
|
||||
result = strategy.experimental_local_results(
|
||||
strategy.run(run, args=(per_replica_value,)))
|
||||
self.assertAllEqual(result, expected_result)
|
||||
@ -469,7 +481,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
|
||||
# 1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))
|
||||
raise ValueError('Add your own expect according to num_replicas_in sync')
|
||||
|
||||
expected_per_replica_1 = [expect_1] * _get_num_devices_per_worker(strategy)
|
||||
expected_per_replica_1 = [expect_1] * _get_num_replicas_per_client(strategy)
|
||||
|
||||
value_2 = constant_op.constant([[[1, 2], [1, 2]]])
|
||||
|
||||
@ -485,7 +497,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
|
||||
# [value_2 for _ in range(strategy.num_replicas_in_sync)], axis=axis)
|
||||
raise ValueError('Add your own expect according to num_replicas_in sync')
|
||||
|
||||
expected_per_replica_2 = [expect_2] * _get_num_devices_per_worker(strategy)
|
||||
expected_per_replica_2 = [expect_2] * _get_num_replicas_per_client(strategy)
|
||||
|
||||
def run(value):
|
||||
value_1 = array_ops.identity(value)
|
||||
@ -517,7 +529,7 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
all_value = [single_value for _ in range(strategy.num_replicas_in_sync)]
|
||||
expect = array_ops.concat(all_value, axis=axis)
|
||||
expected_per_replica = [expect] * _get_num_devices_per_worker(strategy)
|
||||
expected_per_replica = [expect] * _get_num_replicas_per_client(strategy)
|
||||
|
||||
result = strategy.run(run)
|
||||
for gathered_result in result:
|
||||
@ -527,9 +539,13 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
|
||||
"""Different at non-`axis`-th dimension : [2, 1], [1, 1], all_gather(...axis=1...) -> raise error."""
|
||||
if _get_num_devices_per_worker(strategy) > 1:
|
||||
if isinstance(strategy, CollectiveAllReduceStrategy
|
||||
) and _get_num_replicas_per_client(strategy) > 1:
|
||||
self.skipTest('b/167331966')
|
||||
|
||||
if strategy.num_replicas_in_sync <= 1:
|
||||
self.skipTest('Test for more than 1 replica only.')
|
||||
|
||||
def value_fn(ctx):
|
||||
return constant_op.constant(
|
||||
1, shape=(1, ctx.replica_id_in_sync_group + 1))
|
||||
@ -571,7 +587,8 @@ class AllGatherTest(test.TestCase, parameterized.TestCase):
|
||||
"""Different rank: [1,], [1, 2] -> raise error."""
|
||||
if strategy.num_replicas_in_sync <= 1:
|
||||
self.skipTest('Test for more than 1 replicas.')
|
||||
if _get_num_devices_per_worker(strategy) > 1:
|
||||
if isinstance(strategy, CollectiveAllReduceStrategy
|
||||
) and _get_num_replicas_per_client(strategy) > 1:
|
||||
self.skipTest('b/167331966')
|
||||
def value_fn(ctx):
|
||||
return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))
|
||||
@ -601,10 +618,12 @@ def _make_indexed_slices(values, indices, dense_shape):
|
||||
return tensor
|
||||
|
||||
|
||||
def _get_num_devices_per_worker(strategy):
|
||||
"""Returns the number of workers in the current cluster for multi-worker."""
|
||||
resolver = strategy.cluster_resolver
|
||||
return max(nest.flatten(resolver.num_accelerators())[0], 1)
|
||||
def _get_num_replicas_per_client(strategy):
|
||||
if isinstance(strategy, CollectiveAllReduceStrategy):
|
||||
resolver = strategy.cluster_resolver
|
||||
return max(nest.flatten(resolver.num_accelerators())[0], 1)
|
||||
else:
|
||||
return strategy.num_replicas_in_sync
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user