# Library target for test_util.py, whose module docstring describes it as
# "Test utilities."
py_library(
    name = "test_util",
    srcs = ["test_util.py"],
    srcs_version = "PY3",
    deps = [
        ":collective_all_reduce_strategy",
        ":cross_device_utils",
        ":distribute_utils",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/types",
    ],
)

# Test target for :test_util (test_util_test.py).
# NOTE(review): the "multi_and_single_gpu" tag presumably schedules the test on
# both single- and multi-GPU configurations -- confirm against the
# distribute_py_test macro.
distribute_py_test(
    name = "test_util_test",
    srcs = ["test_util_test.py"],
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        ":test_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
def gather(strategy, value):
  """Gathers every leaf of `value` from all replicas.

  Stop-gap helper intended for tests until an official all-gather API is
  implemented.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure whose leaves are n-dim
      `tf.distribute.DistributedValue` of `tf.Tensor`. A plain `tf.Tensor` leaf
      is accepted if the strategy only has one replica.

  Returns:
    The result of mapping the per-leaf gather over `value` with
    `nest.map_structure`; each gathered leaf is a (n+1)-dim `tf.Tensor`.
  """

  def gather_leaf(leaf):
    return _gather(strategy, leaf)

  return nest.map_structure(gather_leaf, value)


def _gather(strategy, value):
  """Gathers a single (non-nested) value across the replicas of `strategy`."""
  # pylint: disable=protected-access
  if isinstance(value, values.DistributedValues):
    per_replica = value
  else:
    # A bare tensor is treated as the value of a single replica.
    assert isinstance(value, core.Tensor)
    per_replica = values.PerReplica([value])

  extended = strategy.extended
  components = per_replica._values

  uses_collectives = isinstance(
      extended, collective_all_reduce_strategy.CollectiveAllReduceExtended)
  if not uses_collectives:
    # Every component is available locally, so stacking them is enough.
    return array_ops.stack(components)

  worker_devices = extended.worker_devices
  assert len(worker_devices) == len(components)

  # Give each component a leading axis of size one before the collective
  # gather is built over them.
  expanded = [array_ops.expand_dims_v2(t, axis=0) for t in components]
  keys = extended._collective_keys
  num_replicas = strategy.num_replicas_in_sync

  @def_function.function
  def gather_fn():
    collected = cross_device_utils.build_collective_gather(
        expanded, worker_devices, num_replicas, keys)
    return distribute_utils.update_regroup(extended, collected, group=True)

  return gather_fn()
  # pylint: enable=protected-access
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
        ] + strategy_combinations.strategies_minus_tpu,
        mode=['eager', 'graph']))
class GatherTest(test.TestCase, parameterized.TestCase):
  """Exercises `test_util.gather` for each strategy/mode combination above."""

  def testOne(self, strategy):
    # Each replica yields a scalar 1.0; gathering should give one entry per
    # replica in sync.

    @def_function.function
    def replica_fn():
      return array_ops.ones((), dtypes.float32)

    per_replica = strategy.run(replica_fn)
    gathered = test_util.gather(strategy, per_replica)
    self.assertAllEqual(
        self.evaluate(gathered), [1.] * strategy.num_replicas_in_sync)

  def testNest(self, strategy):
    # Same as testOne, but the replica function returns a dict containing a
    # list, so every leaf of the structure must be gathered independently.

    @def_function.function
    def replica_fn():
      one = array_ops.ones((), dtypes.float32)
      bar = [
          array_ops.zeros((), dtypes.float32),
          array_ops.ones((), dtypes.float32),
      ]
      return {'foo': one, 'bar': bar}

    gathered = test_util.gather(strategy, strategy.run(replica_fn))

    # (gathered leaf, value each replica contributed for that leaf)
    leaf_checks = (
        (gathered['foo'], 1.),
        (gathered['bar'][0], 0.),
        (gathered['bar'][1], 1.),
    )
    for leaf, fill in leaf_checks:
      self.assertAllEqual(
          self.evaluate(leaf), [fill] * strategy.num_replicas_in_sync)


if __name__ == '__main__':
  combinations.main()