TPUStrategy all_gather does not need a custom gradient, since the underlying AllToAll op already has one registered.

PiperOrigin-RevId: 338233794
Change-Id: I6b4d1df26343cadc0508cea7658951b86154683d
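For context, a minimal sketch (not part of this change, illustration only) of the behaviour the gradient tests below exercise: inside strategy.run, the replica context's all_gather concatenates each replica's value along the given axis, and tf.gradients flows back through the collective with no hand-written gradient at the call site. The two-virtual-CPU MirroredStrategy setup and the names in the snippet are assumptions made for the sketch; on TPUStrategy the same call lowers to AllToAll, whose registered gradient is what makes a custom gradient unnecessary.

import tensorflow as tf

# Assumption for the sketch: split one physical CPU into two logical devices
# so MirroredStrategy runs with two replicas.
cpu = tf.config.list_physical_devices('CPU')[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2)
strategy = tf.distribute.MirroredStrategy(['/cpu:0', '/cpu:1'])

@tf.function
def step(c):
  x = tf.constant([[3.], [5.]])
  ctx = tf.distribute.get_replica_context()
  # Gathered value has shape [2, 2]: this replica's x next to the other's.
  gathered = ctx.all_gather(tf.identity(x), axis=1)
  y = gathered * c
  # The gradient w.r.t. the local x sums the per-replica multipliers c,
  # so with c = 1. and 2. on the two replicas it is [[3.], [3.]].
  return tf.gradients(y, [x])[0]

per_replica_c = strategy.experimental_distribute_values_from_function(
    lambda ctx: tf.constant([ctx.replica_id_in_sync_group + 1.0]))
print(strategy.experimental_local_results(
    strategy.run(step, args=(per_replica_c,))))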
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for common methods in strategy classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.platform import test
from tensorflow.python.util import nest


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.default_strategy,
            strategy_combinations.one_device_strategy,
            strategy_combinations.one_device_strategy_gpu,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_one_gpu,
            strategy_combinations.mirrored_strategy_with_two_gpus,
            strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        mode=['eager'],
        pure_eager=[True, False]) + combinations.combine(
            strategy=[
                strategy_combinations.tpu_strategy,
                strategy_combinations.tpu_strategy_packed_var,
                strategy_combinations.tpu_strategy_one_step,
                strategy_combinations.cloud_tpu_strategy,
            ],
            mode=['eager'],
            pure_eager=[False]))
class GatherTest(test.TestCase, parameterized.TestCase):

  def _gather_same_shape_and_verify(self, value_on_replica, axis, pure_eager,
                                    strategy):
    distributed_values = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.identity(value_on_replica))

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    all_results = [
        value_on_replica for _ in range(strategy.num_replicas_in_sync)
    ]
    expected_result = array_ops.concat(all_results, axis=axis)
    self.assertAllEqual(expected_result, run().numpy())

  def testGatherPerReplicaDense1D0Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [3] on each of two replicas gathers to a tensor of shape [6]."""
    single_value = constant_op.constant([1, 2, 3])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense2D0Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 3] on each of two replicas gathers along the 0th dim to a tensor of shape [2, 3]."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense2D1Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 3] on each of two replicas gathers along the 1st dim to a tensor of shape [1, 6]."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 1
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D0Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 2, 2] on each of two replicas gathers along the 0th dim to a tensor of shape [2, 2, 2]."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D1Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 2, 2] on each of two replicas gathers along the 1st dim to a tensor of shape [1, 4, 2]."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 1
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D2Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 2, 2] on each of two replicas gathers along the 2nd dim to a tensor of shape [1, 2, 4]."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 2
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherDiffShapeAtAxis0(self, strategy, pure_eager):
    """Different `axis`-th (0th) dimension: shape [1, 1], [2, 1] -> [3, 1]."""

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(ctx.replica_id_in_sync_group + 1, 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = constant_op.constant(
        1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1))

    self.assertAllEqual(expected_result, run().numpy())

  def testGatherDiffShapeAtAxis1(self, strategy, pure_eager):
    """Different `axis`-th (non-0th) dimension: shape [1, 1], [1, 2] -> [1, 3]."""

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 1

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    self.assertAllEqual(expected_result, run().numpy())

  def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
    """Different at non-`axis`-th dimension: [1, 1], [1, 2], 0th -> raise error."""
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        run()
    elif isinstance(
        strategy,
        (mirrored_strategy.MirroredStrategy,
         central_storage_strategy.CentralStorageStrategy)) and pure_eager:
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Dimensions of inputs should match'):
        run()
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        run()

  def testGatherRaiseSparse(self, strategy, pure_eager):
    dense_shape = [5, 2]
    t0 = _make_indexed_slices(
        values=[[1., 2.]], indices=[2], dense_shape=dense_shape)

    def run(value):
      return strategy.gather(value, axis=0)

    with self.assertRaisesRegex(
        NotImplementedError,
        r'gather does not support IndexedSlices'):
      if pure_eager:
        run(t0)
      else:
        def_function.function(run)(t0)

  def testGatherRaiseDifferentRank(self, strategy, pure_eager):
    """Different rank: [1,], [1, 2] -> raise error."""
    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test requires more than 1 replica.')
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    def value_fn(ctx):
      return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        run()
    elif isinstance(
        strategy,
        (mirrored_strategy.MirroredStrategy,
         central_storage_strategy.CentralStorageStrategy)) and pure_eager:
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Ranks of all input tensors should match'):
        run()
    elif _is_tpu_strategy(strategy) and pure_eager:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        run()
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Shape must be rank \d but is rank \d'):
        run()

  # Ideally, these tests would be split into a separate test class,
  # AllGatherTest, but doing that makes two initialize_tpu_system() calls and
  # one of them times out on Kokoro. Keeping them in one class avoids that.
  def _all_gather_same_shape_and_verify(self, value_on_replica, axis,
                                        pure_eager, strategy):
    per_replica_value = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.identity(value_on_replica))

    def replica_fn(per_replica_value):
      ctx = ds_context.get_replica_context()
      local_value = array_ops.identity(per_replica_value)
      return ctx.all_gather(local_value, axis=axis)

    if not pure_eager:
      replica_fn = def_function.function(replica_fn)

    result = strategy.experimental_local_results(
        strategy.run(replica_fn, args=(per_replica_value,)))

    all_value = [value_on_replica for _ in range(strategy.num_replicas_in_sync)]
    expect = array_ops.concat(all_value, axis=axis)
    expected_result = [expect] * _get_num_replicas_per_client(strategy)

    self.assertAllClose(expected_result, result)

  def testAllGatherPerReplicaDense1D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) on a DistributedValues with a Tensor of shape (3,) on each of two replicas returns a PerReplica of tensors with shape (6,)."""
    single_value = constant_op.constant([1, 2, 3], dtype=dtypes.float32)
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense2D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) on a DistributedValues with a Tensor of shape (1, 3) on each of two replicas returns a PerReplica of tensors with shape (2, 3)."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense2D1Axis(self, strategy, pure_eager):
    """all_gather(..., axis=1, ...) on a DistributedValues with a Tensor of shape (1, 3) on each of two replicas returns a PerReplica of tensors with shape (1, 6)."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 1
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) on a DistributedValues with a Tensor of shape (1, 2, 2) on each of two replicas returns a PerReplica of tensors with shape (2, 2, 2)."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D1Axis(self, strategy, pure_eager):
    """all_gather(..., axis=1, ...) on a DistributedValues with a Tensor of shape (1, 2, 2) on each of two replicas returns a PerReplica of tensors with shape (1, 4, 2)."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 1
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D2Axis(self, strategy, pure_eager):
    """all_gather(..., axis=2, ...) on a DistributedValues with a Tensor of shape (1, 2, 2) on each of two replicas returns a PerReplica of tensors with shape (1, 2, 4)."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 2
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherDiffValueTPU(self, strategy, pure_eager):
    # TPU-only test, since this case cannot be exercised via
    # testAllGatherDiffShape*.
    if not _is_tpu_strategy(strategy):
      self.skipTest('Test for TPU only. For other strategies this case is'
                    ' already covered by other tests.')

    data = [[1], [2], [3], [4], [5], [6], [7], [8]]

    axis = 0
    dataset = dataset_ops.DatasetV2.from_tensor_slices(data).batch(8)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    @def_function.function
    def replica_fn(per_replica_value):
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(per_replica_value), axis=axis)

    result = strategy.experimental_local_results(
        strategy.run(replica_fn, args=(next(input_iterator),)))

    expected_result = [data] * _get_num_replicas_per_client(strategy)
    self.assertAllClose(expected_result, result)

  def testAllGatherDiffShapeAtAxis0(self, strategy, pure_eager):
    """Different `Axis==0`-th dimension: shape [1, 1], [2, 1] -> [3, 1]."""

    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather different shapes')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(ctx.replica_id_in_sync_group + 1, 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    expect = constant_op.constant(
        1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1))

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = [expect] * _get_num_replicas_per_client(strategy)
    result = strategy.experimental_local_results(
        strategy.run(run, args=(per_replica_value,)))
    self.assertAllEqual(expected_result, result)

  def testAllGatherDiffShapeAtAxis1(self, strategy, pure_eager):
    """Different `Axis`-th (not 0th) dimension: shape [1, 1], [1, 2] -> [1, 3]."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather different shapes')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    expect = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=1)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = [expect] * _get_num_replicas_per_client(strategy)
    result = strategy.experimental_local_results(
        strategy.run(run, args=(per_replica_value,)))
    self.assertAllEqual(expected_result, result)

  def testAllGatherNest(self, strategy, pure_eager):
    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather different shapes')

    axis = 1

    def value_fn(ctx):
      value = constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))
      return value

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    expect_1 = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    expected_per_replica_1 = [expect_1] * _get_num_replicas_per_client(strategy)

    value_2 = constant_op.constant([[[1, 2], [1, 2]]])

    expect_2 = array_ops.concat(
        [value_2 for _ in range(strategy.num_replicas_in_sync)], axis=axis)

    expected_per_replica_2 = [expect_2] * _get_num_replicas_per_client(strategy)

    def run(value):
      value_1 = array_ops.identity(value)
      value_3 = array_ops.identity(value_2)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather([value_1, value_3], axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    result = strategy.run(run, args=(per_replica_value,))
    self.assertAllEqual(expected_per_replica_1,
                        strategy.experimental_local_results(result[0]))
    self.assertAllEqual(expected_per_replica_2,
                        strategy.experimental_local_results(result[1]))

  def testAllGatherNest1D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0,...) a nest of DistributedValues."""
    single_value = constant_op.constant([1, 2, 3])
    axis = 0

    def run():
      value_identity = array_ops.identity(single_value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather([value_identity, value_identity], axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    all_value = [single_value for _ in range(strategy.num_replicas_in_sync)]
    expect = array_ops.concat(all_value, axis=axis)
    expected_per_replica = [expect] * _get_num_replicas_per_client(strategy)

    result = strategy.run(run)
    for gathered_result in result:
      self.assertAllEqual(expected_per_replica,
                          strategy.experimental_local_results(gathered_result))

  def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
    """Different at non-`axis`-th dimension: [1, 1], [1, 2], all_gather(..., axis=0, ...) -> raise error."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        strategy.run(run, args=(per_replica_value,))
    elif isinstance(
        strategy,
        (mirrored_strategy.MirroredStrategy,
         central_storage_strategy.CentralStorageStrategy)) and pure_eager:
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Dimensions of inputs should match'):
        strategy.run(run, args=(per_replica_value,))
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        strategy.run(run, args=(per_replica_value,))

  def testAllGatherRaiseSparse(self, strategy, pure_eager):
    dense_shape = [5, 2]
    t0 = _make_indexed_slices(
        values=[[1., 2.]], indices=[2], dense_shape=dense_shape)

    def replica_fn(value):
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value, axis=0)

    with self.assertRaisesRegex(
        NotImplementedError,
        r'all_gather does not support IndexedSlices'):
      if not pure_eager:
        strategy.run(def_function.function(replica_fn), args=(t0,))
      else:
        strategy.run(replica_fn, args=(t0,))

  def testAllGatherRaiseDifferentRank(self, strategy, pure_eager):
    """Different rank: [1,], [1, 2] -> raise error."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test requires more than 1 replica.')
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    def value_fn(ctx):
      return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        strategy.run(run, args=(per_replica_value,))
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      if pure_eager:
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    r'Ranks of all input tensors should match'):
          strategy.run(run, args=(per_replica_value,))
      else:
        with self.assertRaisesRegex(ValueError,
                                    r'Shape must be rank \d but is rank \d'):
          strategy.run(run, args=(per_replica_value,))
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        strategy.run(run, args=(per_replica_value,))

  def testAllGatherGradient(self, strategy, pure_eager):
    if pure_eager:
      self.skipTest('`tf.gradients` is not supported with eager execution '
                    'without using tf.function.')

    def all_gather_fn(value):
      axis = 1
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(value), axis)

    gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1))
    gradient = [[gradient_comp], [gradient_comp]]
    grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy)

    @def_function.function
    def step(c):
      x = constant_op.constant([[3.], [5.]])
      mid = all_gather_fn(x)
      y = mid * c
      return gradients_impl.gradients_v2(y, [x])[0]

    def value_fn(ctx):
      x = [1., 2., 3., 4., 5., 6., 7., 8.]
      return array_ops.constant([x[ctx.replica_id_in_sync_group]])

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    result = strategy.experimental_local_results(
        strategy.run(step, args=(per_replica_value,)))

    self.assertAllEqual(grads_for_all_replicas, result)

  def testAllGatherGradientNest(self, strategy, pure_eager):
    if pure_eager:
      self.skipTest('`tf.gradients` is not supported with eager execution '
                    'without using tf.function.')

    def all_gather_fn(value):
      axis = 1
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(value), axis)

    gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1))
    gradient = [[gradient_comp], [gradient_comp]]
    grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy)

    @def_function.function
    def step(c):
      x = constant_op.constant([[3.], [5.]])
      y = constant_op.constant([[2.], [4.]])
      mid = all_gather_fn([x, y])
      y = mid * c
      return gradients_impl.gradients_v2(y, [x])[0]

    def value_fn(ctx):
      x = [1., 2., 3., 4., 5., 6., 7., 8.]
      return array_ops.constant([x[ctx.replica_id_in_sync_group]])

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    result = strategy.experimental_local_results(
        strategy.run(step, args=(per_replica_value,)))

    self.assertAllEqual(grads_for_all_replicas, result)


def _make_indexed_slices(values, indices, dense_shape):
  tensor = ops.IndexedSlices(
      values=constant_op.constant(values),
      indices=constant_op.constant(indices),
      dense_shape=constant_op.constant(dense_shape))
  return tensor


def _get_num_replicas_per_client(strategy):
  if isinstance(strategy, CollectiveAllReduceStrategy):
    resolver = strategy.cluster_resolver
    return max(nest.flatten(resolver.num_accelerators())[0], 1)
  else:
    return strategy.num_replicas_in_sync


def _is_tpu_strategy(strategy):
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))


if __name__ == '__main__':
  test_util.main()