Add gather() util for multi-worker tests
It takes an n-dim PerReplica value and returns an (n+1)-dim tensor. PiperOrigin-RevId: 318588244 Change-Id: I821f7701a6066e178db54e12a66e773c1767314c
This commit is contained in:
parent
4c268edb3c
commit
897803b00c
@ -20,6 +20,7 @@ py_library(
|
|||||||
":single_loss_example",
|
":single_loss_example",
|
||||||
":strategy_combinations",
|
":strategy_combinations",
|
||||||
":strategy_test_lib",
|
":strategy_test_lib",
|
||||||
|
":test_util",
|
||||||
"//tensorflow/python/keras/distribute:keras_correctness_test_lib",
|
"//tensorflow/python/keras/distribute:keras_correctness_test_lib",
|
||||||
"//tensorflow/python/keras/distribute:keras_test_lib",
|
"//tensorflow/python/keras/distribute:keras_test_lib",
|
||||||
],
|
],
|
||||||
@ -1858,3 +1859,38 @@ distribute_py_test(
|
|||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Library target wrapping test_util.py.
# NOTE(review): test_util.py as added in this change does not import the
# framework ops module; confirm whether "//tensorflow/python:framework_ops"
# is actually needed.
py_library(
    name = "test_util",
    srcs = ["test_util.py"],
    srcs_version = "PY3",
    deps = [
        ":collective_all_reduce_strategy",
        ":cross_device_utils",
        ":distribute_utils",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/types",
    ],
)
|
||||||
|
|
||||||
|
# Test target for test_util_test.py, which exercises the ":test_util" library.
# Tagged "multi_and_single_gpu".
distribute_py_test(
    name = "test_util_test",
    srcs = ["test_util_test.py"],
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        ":combinations",
        ":strategy_combinations",
        ":test_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
|
||||||
|
72
tensorflow/python/distribute/test_util.py
Normal file
72
tensorflow/python/distribute/test_util.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Test utilities."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||||
|
from tensorflow.python.distribute import cross_device_utils
|
||||||
|
from tensorflow.python.distribute import distribute_utils
|
||||||
|
from tensorflow.python.distribute import values
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.types import core
|
||||||
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
|
def gather(strategy, value):
  """Collects `value` from every worker, for use in tests.

  This is a stop-gap helper for tests until an official all-gather API exists.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure whose leaves are n-dim
      `tf.distribute.DistributedValue`s of `tf.Tensor`, or plain `tf.Tensor`s
      if the strategy only has one replica.

  Returns:
    The result of applying the single-value gather to every leaf of `value`
    through `nest.map_structure`; each gathered leaf is a (n+1)-dim
    `tf.Tensor`.
  """

  def gather_leaf(leaf):
    # Bind `strategy` so that `nest.map_structure` only has to pass the leaf.
    return _gather(strategy, leaf)

  return nest.map_structure(gather_leaf, value)
|
||||||
|
|
||||||
|
|
||||||
|
def _gather(strategy, value):
  """Gathers one leaf value into a result with a new leading axis."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    # Anything that is not already distributed must be a plain tensor; wrap it
    # as a one-component PerReplica so both cases share the code below.
    assert isinstance(value, core.Tensor)
    value = values.PerReplica([value])

  extended = strategy.extended
  components = value._values
  if not isinstance(
      extended, collective_all_reduce_strategy.CollectiveAllReduceExtended):
    # Strategies other than collective all-reduce: stack the components
    # directly.
    return array_ops.stack(components)

  worker_devices = extended.worker_devices
  # There must be exactly one component per worker device.
  assert len(worker_devices) == len(components)
  # Give every component an explicit axis 0 before the collective gather.
  expanded = [
      array_ops.expand_dims_v2(component, axis=0) for component in components
  ]
  keys = extended._collective_keys
  num_replicas = strategy.num_replicas_in_sync

  @def_function.function
  def gather_fn():
    collected = cross_device_utils.build_collective_gather(
        expanded, worker_devices, num_replicas, keys)
    return distribute_utils.update_regroup(extended, collected, group=True)

  return gather_fn()
  # pylint: enable=protected-access
|
75
tensorflow/python/distribute/test_util_test.py
Normal file
75
tensorflow/python/distribute/test_util_test.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for test utilities."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.distribute import combinations
|
||||||
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
|
from tensorflow.python.distribute import test_util
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
|
||||||
|
|
||||||
|
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
        ] + strategy_combinations.strategies_minus_tpu,
        mode=['eager', 'graph']))
class GatherTest(test.TestCase, parameterized.TestCase):
  """Exercises `test_util.gather` over the parameterized strategies."""

  def testOne(self, strategy):
    """A scalar one from each replica gathers to a vector of ones."""

    @def_function.function
    def replica_fn():
      return array_ops.ones((), dtypes.float32)

    per_replica = strategy.run(replica_fn)
    gathered = test_util.gather(strategy, per_replica)
    # One entry is expected per replica in sync.
    expected = [1.] * strategy.num_replicas_in_sync
    self.assertAllEqual(self.evaluate(gathered), expected)

  def testNest(self, strategy):
    """A nested dict/list of scalars is gathered leaf by leaf."""

    @def_function.function
    def replica_fn():
      return {
          'foo': array_ops.ones((), dtypes.float32),
          'bar': [
              array_ops.zeros((), dtypes.float32),
              array_ops.ones((), dtypes.float32),
          ],
      }

    per_replica = strategy.run(replica_fn)
    gathered = test_util.gather(strategy, per_replica)

    foo = self.evaluate(gathered['foo'])
    self.assertAllEqual(foo, [1.] * strategy.num_replicas_in_sync)

    bar_first = self.evaluate(gathered['bar'][0])
    self.assertAllEqual(bar_first, [0.] * strategy.num_replicas_in_sync)

    bar_second = self.evaluate(gathered['bar'][1])
    self.assertAllEqual(bar_second, [1.] * strategy.num_replicas_in_sync)
|
||||||
|
|
||||||
|
|
||||||
|
# When run as a script, hand control to the `main` entry point of the
# `combinations` module.
if __name__ == '__main__':
  combinations.main()
|
Loading…
Reference in New Issue
Block a user