Add gather() util for multi worker tests

It takes a PerReplica value of n-dim, and returns a tensor of (n+1)-dim.

PiperOrigin-RevId: 318588244
Change-Id: I821f7701a6066e178db54e12a66e773c1767314c
This commit is contained in:
Ran Chen 2020-06-26 20:20:16 -07:00 committed by TensorFlower Gardener
parent 4c268edb3c
commit 897803b00c
3 changed files with 183 additions and 0 deletions

View File

@ -20,6 +20,7 @@ py_library(
":single_loss_example", ":single_loss_example",
":strategy_combinations", ":strategy_combinations",
":strategy_test_lib", ":strategy_test_lib",
":test_util",
"//tensorflow/python/keras/distribute:keras_correctness_test_lib", "//tensorflow/python/keras/distribute:keras_correctness_test_lib",
"//tensorflow/python/keras/distribute:keras_test_lib", "//tensorflow/python/keras/distribute:keras_test_lib",
], ],
@ -1858,3 +1859,38 @@ distribute_py_test(
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",
], ],
) )
py_library(
name = "test_util",
srcs = ["test_util.py"],
srcs_version = "PY3",
deps = [
":collective_all_reduce_strategy",
":cross_device_utils",
":distribute_utils",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/types",
],
)
distribute_py_test(
name = "test_util_test",
srcs = ["test_util_test.py"],
tags = [
"multi_and_single_gpu",
],
deps = [
":combinations",
":strategy_combinations",
":test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -0,0 +1,72 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function
from tensorflow.python.ops import array_ops
from tensorflow.python.types import core
from tensorflow.python.util import nest
def gather(strategy, value):
"""Gathers value from all workers.
This is intended for tests before we implement an official all-gather API.
Args:
strategy: a `tf.distribute.Strategy`.
value: a nested structure of n-dim `tf.distribute.DistributedValue` of
`tf.Tensor`, or of a `tf.Tensor` if the strategy only has one replica.
Returns:
a (n+1)-dim `tf.Tensor`.
"""
return nest.map_structure(functools.partial(_gather, strategy), value)
def _gather(strategy, value):
"""Gathers a single value."""
# pylint: disable=protected-access
if not isinstance(value, values.DistributedValues):
assert isinstance(value, core.Tensor)
value = values.PerReplica([value])
if not isinstance(strategy.extended,
collective_all_reduce_strategy.CollectiveAllReduceExtended):
return array_ops.stack(value._values)
assert len(strategy.extended.worker_devices) == len(value._values)
inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
collective_keys = strategy.extended._collective_keys
devices = strategy.extended.worker_devices
group_size = strategy.num_replicas_in_sync
@def_function.function
def gather_fn():
gathered = cross_device_utils.build_collective_gather(
inputs, devices, group_size, collective_keys)
return distribute_utils.update_regroup(
strategy.extended, gathered, group=True)
return gather_fn()
# pylint: enable=protected-access

View File

@ -0,0 +1,75 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for test utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
] + strategy_combinations.strategies_minus_tpu,
mode=['eager', 'graph']))
class GatherTest(test.TestCase, parameterized.TestCase):
def testOne(self, strategy):
@def_function.function
def f():
return array_ops.ones((), dtypes.float32)
results = test_util.gather(strategy, strategy.run(f))
self.assertAllEqual(
self.evaluate(results), [1.] * strategy.num_replicas_in_sync)
def testNest(self, strategy):
@def_function.function
def f():
return {
'foo':
array_ops.ones((), dtypes.float32),
'bar': [
array_ops.zeros((), dtypes.float32),
array_ops.ones((), dtypes.float32),
]
}
results = test_util.gather(strategy, strategy.run(f))
self.assertAllEqual(
self.evaluate(results['foo']), [1.] * strategy.num_replicas_in_sync)
self.assertAllEqual(
self.evaluate(results['bar'][0]), [0.] * strategy.num_replicas_in_sync)
self.assertAllEqual(
self.evaluate(results['bar'][1]), [1.] * strategy.num_replicas_in_sync)
if __name__ == '__main__':
combinations.main()