Graduate experimental_hints to options in all_reduce/reduce/batch_reduce
The CollectiveHints class is also renamed to CommunicationOptions, and the communication implementation enum is added to it as an option. CommunicationOptions stays experimental since the detailed options may change, but it's clear that we need an options argument for these cross-device communications.

PiperOrigin-RevId: 337547832
Change-Id: I376171672698d5923b4e52f2567d4a584c8e21b6
parent 9ef7492f43
commit f196a243ea
@@ -54,6 +54,13 @@
  tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
  use `tf.data.Dataset.from_tensor_slices` instead.
* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
  `tf.distribute.StrategyExtended.batch_reduce_to`, and
  `tf.distribute.ReplicaContext.all_reduce` are renamed to `options`.
  `tf.distribute.experimental.CollectiveHints` is renamed to
  `tf.distribute.experimental.CommunicationOptions`.
  `tf.distribute.experimental.CollectiveCommunication` is renamed to
  `tf.distribute.experimental.CommunicationImplementation`.

## Known Caveats

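As a quick, hedged illustration of the renames above — `options` replacing `experimental_hints`, `CommunicationOptions` replacing `CollectiveHints`, and `CommunicationImplementation` replacing `CollectiveCommunication` (spelling as in the release note) — here is a minimal sketch; the strategy construction and values are illustrative only:

```python
import tensorflow as tf

# Strategy construction is illustrative; any tf.distribute strategy that
# supports collectives can be used here.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

# CommunicationOptions replaces CollectiveHints, and the keyword argument
# is now `options` instead of `experimental_hints`.
options = tf.distribute.experimental.CommunicationOptions(
    bytes_per_pack=50 * 1024 * 1024,
    timeout_seconds=120.0,
    implementation=tf.distribute.experimental.CommunicationImplementation.RING)


@tf.function
def step_fn():
  value = tf.constant(1.0)
  ctx = tf.distribute.get_replica_context()
  # Before this change: ctx.all_reduce("sum", value, experimental_hints=hints)
  return ctx.all_reduce("sum", value, options=options)


strategy.run(step_fn)
```

These options remain hints: strategies that do not understand a particular field simply ignore it.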
@ -337,6 +337,7 @@ py_library(
|
||||
name = "mirrored_strategy",
|
||||
srcs = ["mirrored_strategy.py"],
|
||||
deps = [
|
||||
":collective_util",
|
||||
":cross_device_ops",
|
||||
":device_util",
|
||||
":distribute_lib",
|
||||
@ -425,6 +426,7 @@ py_library(
|
||||
srcs = ["collective_all_reduce_strategy.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":collective_util",
|
||||
":cross_device_ops",
|
||||
":cross_device_utils",
|
||||
":input_lib",
|
||||
@ -669,12 +671,22 @@ py_library(
|
||||
py_library(
|
||||
name = "collective_util",
|
||||
srcs = ["collective_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "collective_util_test",
|
||||
srcs = ["collective_util_test.py"],
|
||||
deps = [
|
||||
":collective_util",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "shared_variable_creator",
|
||||
srcs = ["shared_variable_creator.py"],
|
||||
|
@ -98,33 +98,34 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
|
||||
Tensorflow API.
|
||||
|
||||
"""
|
||||
|
||||
# TODO(anjalisridhar): Update our guides with examples showing how we can use
|
||||
# the cluster_resolver argument.
|
||||
|
||||
# The starting number for collective keys. This should only be set in tests.
|
||||
_collective_key_base = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
communication=cross_device_ops_lib.CollectiveCommunication.AUTO,
|
||||
cluster_resolver=None):
|
||||
def __init__(self,
|
||||
communication=collective_util.CommunicationImplemenation.AUTO,
|
||||
cluster_resolver=None):
|
||||
"""Creates the strategy.
|
||||
|
||||
Args:
|
||||
communication: optional
|
||||
`tf.distribute.experimental.CollectiveCommunication`. This is a hint on
|
||||
the preferred collective communication implementation. Possible values
|
||||
include `AUTO`, `RING`, and `NCCL`.
|
||||
`tf.distribute.experimental.CommunicationImplemenation`. This is a hint
|
||||
on the preferred collective communication implementation. Possible
|
||||
values include `AUTO`, `RING`, and `NCCL`.
|
||||
cluster_resolver: optional
|
||||
`tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
|
||||
`tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
|
||||
"""
|
||||
# TODO(b/150151677): consider move communication to CollectiveHints.
|
||||
communication_options = collective_util.Options(
|
||||
implementation=communication)
|
||||
super(CollectiveAllReduceStrategy, self).__init__(
|
||||
CollectiveAllReduceExtended(
|
||||
self,
|
||||
communication=communication,
|
||||
cluster_resolver=cluster_resolver))
|
||||
cluster_resolver=cluster_resolver,
|
||||
communication_options=communication_options))
|
||||
|
||||
distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
|
||||
"MultiWorkerMirroredStrategy")
|
||||
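The constructor hunk above keeps the public `communication` enum argument but immediately wraps it in `collective_util.Options(implementation=communication)` and hands that to the extended strategy as `communication_options`. A small self-contained sketch of that enum-to-options adapter (hypothetical class names, not the TensorFlow internals):

```python
import enum


class Implementation(enum.Enum):
  """Hypothetical stand-in for the communication implementation enum."""
  AUTO = "AUTO"
  RING = "RING"
  NCCL = "NCCL"


class Options(object):
  """Hypothetical stand-in for collective_util.Options."""

  def __init__(self, implementation=Implementation.AUTO):
    self.implementation = implementation


class Strategy(object):
  """Keeps the enum-valued public argument, forwards an options object."""

  def __init__(self, communication=Implementation.AUTO):
    # Same shape as the change above: the legacy enum argument is wrapped
    # into an options object before it reaches the internals.
    self._communication_options = Options(implementation=communication)


s = Strategy(communication=Implementation.NCCL)
assert s._communication_options.implementation is Implementation.NCCL
```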
@ -138,7 +139,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
|
||||
def _from_local_devices(
|
||||
cls,
|
||||
devices,
|
||||
communication=cross_device_ops_lib.CollectiveCommunication.AUTO):
|
||||
communication=collective_util.CommunicationImplemenation.AUTO):
|
||||
"""A convenience method to create an object with a list of devices."""
|
||||
obj = cls(communication)
|
||||
obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access
|
||||
@ -162,16 +163,17 @@ class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
|
||||
|
||||
__doc__ = CollectiveAllReduceStrategy.__doc__
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
communication=cross_device_ops_lib.CollectiveCommunication.AUTO,
|
||||
cluster_resolver=None):
|
||||
def __init__(self,
|
||||
communication=collective_util.CommunicationImplemenation.AUTO,
|
||||
cluster_resolver=None):
|
||||
"""Initializes the object."""
|
||||
communication_options = collective_util.Options(
|
||||
implementation=communication)
|
||||
super(CollectiveAllReduceStrategyV1, self).__init__(
|
||||
CollectiveAllReduceExtended(
|
||||
self,
|
||||
communication=communication,
|
||||
cluster_resolver=cluster_resolver))
|
||||
cluster_resolver=cluster_resolver,
|
||||
communication_options=communication_options))
|
||||
distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
|
||||
"MultiWorkerMirroredStrategy")
|
||||
# pylint: disable=protected-access
|
||||
@ -196,16 +198,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
# Times to retry before considering the peer is down.
|
||||
_check_health_retry_limit = 3
|
||||
|
||||
def __init__(self,
|
||||
container_strategy,
|
||||
communication,
|
||||
cluster_resolver):
|
||||
def __init__(self, container_strategy, cluster_resolver,
|
||||
communication_options):
|
||||
self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
|
||||
distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
|
||||
assert isinstance(
|
||||
communication,
|
||||
cross_device_ops_lib.CollectiveCommunication)
|
||||
self._communication = communication
|
||||
self._communication_options = communication_options
|
||||
self._initialize_strategy(self._cluster_resolver)
|
||||
self._cfer_fn_cache = weakref.WeakKeyDictionary()
|
||||
self.experimental_enable_get_next_as_optional = True
|
||||
@ -255,15 +252,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
devices=local_devices,
|
||||
group_size=len(local_devices),
|
||||
collective_keys=self._collective_keys,
|
||||
communication=self._communication)
|
||||
collective_keys=self._collective_keys)
|
||||
# CrossDeviceOps for per host tensors.
|
||||
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
devices=[self._worker_device],
|
||||
group_size=self._num_workers,
|
||||
collective_keys=self._collective_keys,
|
||||
communication=cross_device_ops_lib.CollectiveCommunication.RING,
|
||||
)
|
||||
collective_keys=self._collective_keys)
|
||||
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
|
||||
local_devices)
|
||||
|
||||
@ -283,8 +277,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
self._rpc_layer = cluster_resolver.rpc_layer
|
||||
self._warn_nccl_no_gpu()
|
||||
|
||||
logging.info("Single-worker MultiWorkerMirroredStrategy with local_devices "
|
||||
"= %r, communication = %s", local_devices, self._communication)
|
||||
logging.info(
|
||||
"Single-worker MultiWorkerMirroredStrategy with local_devices "
|
||||
"= %r, communication = %s", local_devices,
|
||||
self._communication_options.implementation)
|
||||
|
||||
def _initialize_multi_worker(self, cluster_resolver):
|
||||
"""Initializes the object for multi-worker training."""
|
||||
@ -371,15 +367,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
devices=local_devices,
|
||||
group_size=len(local_devices) * self._num_workers,
|
||||
collective_keys=self._collective_keys,
|
||||
communication=self._communication)
|
||||
collective_keys=self._collective_keys)
|
||||
# CrossDeviceOps for per host tensors.
|
||||
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
devices=[self._worker_device],
|
||||
group_size=self._num_workers,
|
||||
collective_keys=self._collective_keys,
|
||||
communication=cross_device_ops_lib.CollectiveCommunication.RING,
|
||||
)
|
||||
collective_keys=self._collective_keys)
|
||||
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
|
||||
local_devices)
|
||||
|
||||
@ -398,9 +391,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
logging.info(
|
||||
"MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
|
||||
"task_id = %r, num_workers = %r, local_devices = %r, "
|
||||
"communication = %s", cluster_spec.as_dict(), task_type,
|
||||
task_id, self._num_workers, local_devices,
|
||||
self._communication)
|
||||
"communication = %s", cluster_spec.as_dict(), task_type, task_id,
|
||||
self._num_workers, local_devices,
|
||||
self._communication_options.implementation)
|
||||
|
||||
def __del__(self):
|
||||
self._stop_check_health_thread()
|
||||
@ -571,8 +564,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
|
||||
|
||||
if (not ops.executing_eagerly_outside_functions() and
|
||||
self._communication ==
|
||||
cross_device_ops_lib.CollectiveCommunication.NCCL):
|
||||
self._communication_options.implementation ==
|
||||
collective_util.CommunicationImplemenation.NCCL):
|
||||
updated_config.experimental.collective_nccl = True
|
||||
|
||||
if not self._cluster_spec:
|
||||
@ -610,15 +603,14 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
else:
|
||||
return self._host_cross_device_ops
|
||||
|
||||
def _gather_to_implementation(self, value, destinations, axis,
|
||||
experimental_hints):
|
||||
def _gather_to_implementation(self, value, destinations, axis, options):
|
||||
return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access
|
||||
value,
|
||||
destinations=destinations,
|
||||
axis=axis,
|
||||
experimental_hints=experimental_hints)
|
||||
options=options)
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
def _reduce_to(self, reduce_op, value, destinations, options):
|
||||
if (isinstance(value, values.Mirrored) and
|
||||
reduce_op == reduce_util.ReduceOp.MEAN):
|
||||
return value
|
||||
@ -642,7 +634,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
reduce_op,
|
||||
value,
|
||||
destinations=destinations,
|
||||
experimental_hints=experimental_hints)
|
||||
options=self._communication_options.merge(options))
|
||||
|
||||
def _check_health(self):
|
||||
while True:
|
||||
@ -704,8 +696,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
reduce_util.ReduceOp.SUM,
|
||||
dummy_value,
|
||||
dummy_value,
|
||||
experimental_hints=collective_util.Hints(
|
||||
timeout_seconds=self._check_health_initial_timeout))
|
||||
options=collective_util.Options(
|
||||
timeout_seconds=self._check_health_initial_timeout,
|
||||
implementation=collective_util.CommunicationImplemenation.RING))
|
||||
if context.is_async():
|
||||
context.async_wait()
|
||||
except errors.DeadlineExceededError:
|
||||
@ -731,8 +724,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
logging.info("check health thread stopped")
|
||||
|
||||
def _warn_nccl_no_gpu(self):
|
||||
if ((self._communication ==
|
||||
cross_device_ops_lib.CollectiveCommunication.NCCL) and
|
||||
if ((self._communication_options.implementation ==
|
||||
collective_util.CommunicationImplemenation.NCCL) and
|
||||
self._num_gpus_per_worker == 0):
|
||||
logging.warning("Enabled NCCL communication but no GPUs detected/"
|
||||
"specified.")
|
||||
|
@ -1,3 +1,4 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -18,9 +19,143 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import enum
|
||||
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# TODO(b/170340570): print deprecation warning for CollectiveCommunication.
|
||||
@tf_export("distribute.experimental.CommunicationImplemenation",
|
||||
"distribute.experimental.CollectiveCommunication")
|
||||
class CommunicationImplemenation(enum.Enum):
|
||||
"""Cross device communication implementation.
|
||||
|
||||
Warning: The alias `tf.distribute.experimental.CollectiveCommunication` is
|
||||
deprecated and will be removed in a future version. Use
|
||||
`tf.distribute.experimental.CommunicationImplemenation` instead.
|
||||
|
||||
* `AUTO`: Automatically chosen by Tensorflow.
|
||||
* `RING`: TensorFlow's ring algorithms for all-reduce and
|
||||
all-gather.
|
||||
* `NCCL`: NVIDIA®'s NCCL library. This is now only used for all-reduce on
|
||||
GPUs; all-reduce on CPU, all-gather and broadcast fall back to RING.
|
||||
"""
|
||||
AUTO = "AUTO"
|
||||
RING = "RING"
|
||||
NCCL = "NCCL"
|
||||
# TODO(ayushd): add ncclAllGather implementation.
|
||||
|
||||
|
||||
CollectiveCommunication = CommunicationImplemenation
|
||||
|
||||
|
||||
@tf_export("distribute.experimental.CommunicationOptions")
|
||||
class _OptionsExported(object):
|
||||
"""Options for cross device communications like All-reduce.
|
||||
|
||||
This can be passed to methods like
|
||||
`tf.distribute.get_replica_context().all_reduce()` to optimize collective
|
||||
operation performance. Note that these are only hints, which may or may not
|
||||
change the actual behavior. Some options only apply to certain strategies and
|
||||
are ignored by others.
|
||||
|
||||
One common optimization is to break gradients all-reduce into multiple packs
|
||||
so that weight updates can overlap with gradient all-reduce.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
options = tf.distribute.experimental.CommunicationOptions(
|
||||
bytes_per_pack=50 * 1024 * 1024,
|
||||
timeout_seconds=120,
|
||||
implementation=tf.distribute.experimental.CommunicationImplemenation.NCCL)
|
||||
grads = tf.distribute.get_replica_context().all_reduce(
|
||||
'sum', grads, options=options)
|
||||
optimizer.apply_gradients(zip(grads, vars),
|
||||
experimental_aggregate_gradients=False)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
return Options.__new__(Options, *args, **kwargs)
|
||||
|
||||
def __init__(self,
|
||||
bytes_per_pack=0,
|
||||
timeout_seconds=None,
|
||||
implementation=CommunicationImplemenation.AUTO):
|
||||
"""Creates a CollectiveHints.
|
||||
|
||||
Args:
|
||||
bytes_per_pack: a non-negative integer. Breaks collective operations into
|
||||
packs of certain size. If it's zero, the value is determined
|
||||
automatically. This only applies to all-reduce with
|
||||
`MultiWorkerMirroredStrategy` currently.
|
||||
timeout_seconds: a float or None, timeout in seconds. If not None, the
|
||||
collective raises `tf.errors.DeadlineExceededError` if it takes longer
|
||||
than this timeout. Zero disables timeout. This can be useful when
|
||||
debugging hanging issues. This should only be used for debugging since
|
||||
it creates a new thread for each collective, i.e. an overhead of
|
||||
`timeout_seconds * num_collectives_per_second` more threads. This only
|
||||
works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||
implementation: a `tf.distribute.experimental.CommunicationImplemenation`.
|
||||
This is a hint on the preferred communication implementation. Possible
|
||||
values include `AUTO`, `RING`, and `NCCL`. NCCL is generally more
|
||||
performant for GPU, but doesn't work for CPU. This only works for
|
||||
`tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||
|
||||
Raises:
|
||||
ValueError: When arguments have invalid value.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Options(object):
|
||||
"""Implementation of OptionsInterface."""
|
||||
|
||||
def __init__(self,
|
||||
bytes_per_pack=0,
|
||||
timeout_seconds=None,
|
||||
implementation=CommunicationImplemenation.AUTO):
|
||||
if bytes_per_pack < 0:
|
||||
raise ValueError("bytes_per_pack must be non-negative")
|
||||
if isinstance(implementation, str):
|
||||
implementation = CommunicationImplemenation(implementation.upper())
|
||||
if not isinstance(implementation, CommunicationImplemenation):
|
||||
raise ValueError("implementation should be a "
|
||||
"tf.distribute.experimental.CommunicationImplemenation")
|
||||
self.bytes_per_pack = bytes_per_pack
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.implementation = implementation
|
||||
|
||||
__init__.__doc__ = _OptionsExported.__init__.__doc__
|
||||
|
||||
def merge(self, options):
|
||||
"""Merges with another options and returns a new one.
|
||||
|
||||
Values specified in `options` take precedence if they're not the
|
||||
default.
|
||||
|
||||
Args:
|
||||
options: a `tf.distribute.experimental.CommunicationOptions`.
|
||||
|
||||
Returns:
|
||||
A new `tf.distribute.experimental.CommunicationOptions`.
|
||||
"""
|
||||
merged = copy.deepcopy(self)
|
||||
if options is None:
|
||||
return merged
|
||||
if options.bytes_per_pack != 0:
|
||||
merged.bytes_per_pack = options.bytes_per_pack
|
||||
if options.timeout_seconds is not None:
|
||||
merged.timeout_seconds = options.timeout_seconds
|
||||
if options.implementation != CommunicationImplemenation.AUTO:
|
||||
merged.implementation = options.implementation
|
||||
return merged
|
||||
|
||||
|
||||
@tf_export("distribute.experimental.CollectiveHints")
|
||||
class Hints(object):
|
||||
"""Hints for collective operations like AllReduce.
|
||||
@ -61,6 +196,12 @@ class Hints(object):
|
||||
|
||||
"""
|
||||
|
||||
@deprecation.deprecated(
|
||||
None, "use distribute.experimental.CommunicationOptions instead")
|
||||
def __new__(cls, bytes_per_pack=0, timeout_seconds=None):
|
||||
return Options(
|
||||
bytes_per_pack=bytes_per_pack, timeout_seconds=timeout_seconds)
|
||||
|
||||
def __init__(self, bytes_per_pack=0, timeout_seconds=None):
|
||||
"""Creates a CollectiveHints.
|
||||
|
||||
@ -80,7 +221,4 @@ class Hints(object):
|
||||
Raises:
|
||||
ValueError: When arguments have invalid value.
|
||||
"""
|
||||
if bytes_per_pack < 0:
|
||||
raise ValueError("bytes_per_pack must be non-negative")
|
||||
self.bytes_per_pack = bytes_per_pack
|
||||
self.timeout_seconds = timeout_seconds
|
||||
pass
|
||||
|
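`Options.merge` above gives a per-call options object precedence over strategy-level defaults, but only for fields that differ from their default values; this is what `_reduce_to` relies on when it calls `self._communication_options.merge(options)`. A standalone sketch of those precedence rules (hypothetical class mirroring the logic shown in the diff):

```python
import copy


class Options(object):
  """Hypothetical re-statement of the merge rules shown above."""

  def __init__(self, bytes_per_pack=0, timeout_seconds=None,
               implementation="AUTO"):
    self.bytes_per_pack = bytes_per_pack
    self.timeout_seconds = timeout_seconds
    self.implementation = implementation

  def merge(self, options):
    # Non-default fields of `options` win over the values in `self`.
    merged = copy.deepcopy(self)
    if options is None:
      return merged
    if options.bytes_per_pack != 0:
      merged.bytes_per_pack = options.bytes_per_pack
    if options.timeout_seconds is not None:
      merged.timeout_seconds = options.timeout_seconds
    if options.implementation != "AUTO":
      merged.implementation = options.implementation
    return merged


strategy_defaults = Options(bytes_per_pack=4 << 20, implementation="NCCL")
per_call = Options(timeout_seconds=30)
merged = strategy_defaults.merge(per_call)
# The per-call timeout wins; unspecified fields keep the strategy defaults.
assert (merged.bytes_per_pack, merged.timeout_seconds,
        merged.implementation) == (4 << 20, 30, "NCCL")
```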
tensorflow/python/distribute/collective_util_test.py (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for utilities for collectives."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import collective_util
from tensorflow.python.eager import test


class OptionsTest(test.TestCase):

  def testCreateOptionsViaExportedAPI(self):
    options = collective_util._OptionsExported()
    self.assertIsInstance(options, collective_util.Options)

  def testCreateOptionsViaHints(self):
    with self.assertLogs() as cm:
      options = collective_util.Hints(50, 1)
    self.assertTrue(any("is deprecated" in msg for msg in cm.output))
    self.assertIsInstance(options, collective_util.Options)
    self.assertEqual(options.bytes_per_pack, 50)
    self.assertEqual(options.timeout_seconds, 1)


if __name__ == "__main__":
  test.main()
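The deprecation test above passes because `Hints.__new__` returns an `Options` instance, so constructing the old class transparently yields the new one. A tiny standalone sketch of that aliasing trick (hypothetical class names):

```python
class NewOptions(object):
  def __init__(self, bytes_per_pack=0, timeout_seconds=None):
    self.bytes_per_pack = bytes_per_pack
    self.timeout_seconds = timeout_seconds


class DeprecatedHints(object):
  """Old name kept for compatibility; constructing it yields NewOptions."""

  def __new__(cls, bytes_per_pack=0, timeout_seconds=None):
    # Returning a different class from __new__ means any __init__ defined
    # on DeprecatedHints is never run; callers just get a NewOptions.
    return NewOptions(bytes_per_pack, timeout_seconds)


h = DeprecatedHints(50, 1)
assert isinstance(h, NewOptions)
assert (h.bytes_per_pack, h.timeout_seconds) == (50, 1)
```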
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import enum
|
||||
import threading
|
||||
|
||||
import six
|
||||
@ -251,11 +250,7 @@ class CrossDeviceOps(object):
|
||||
# Returns 1 by default, the value may be overridden by sub classes.
|
||||
return 1
|
||||
|
||||
def reduce(self,
|
||||
reduce_op,
|
||||
per_replica_value,
|
||||
destinations,
|
||||
experimental_hints=None):
|
||||
def reduce(self, reduce_op, per_replica_value, destinations, options=None):
|
||||
"""Reduce `per_replica_value` to `destinations`.
|
||||
|
||||
See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
|
||||
@ -272,8 +267,8 @@ class CrossDeviceOps(object):
|
||||
`destinations`. Note that if it's a `tf.Variable`, the value is reduced
|
||||
to the devices of that variable, and this method doesn't update the
|
||||
variable.
|
||||
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
|
||||
`tf.distribute.experimental.CollectiveHints` for details.
|
||||
options: a `tf.distribute.experimental.CommunicationOptions`. See
|
||||
`tf.distribute.experimental.CommunicationOptions` for details.
|
||||
|
||||
Returns:
|
||||
A `tf.Tensor` or `tf.distribute.DistributedValues`.
|
||||
@ -283,6 +278,8 @@ class CrossDeviceOps(object):
|
||||
`tf.distribute.DistributedValues` or if destinations is not a string,
|
||||
`tf.Variable` or `tf.distribute.DistributedValues`.
|
||||
"""
|
||||
if options is None:
|
||||
options = collective_util.Options()
|
||||
if not isinstance(per_replica_value, value_lib.DistributedValues):
|
||||
per_replica_value = _make_tensor_into_per_replica(per_replica_value)
|
||||
|
||||
@ -296,16 +293,12 @@ class CrossDeviceOps(object):
|
||||
v = array_ops.identity(per_replica_value.values[0])
|
||||
return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
|
||||
|
||||
if experimental_hints is None:
|
||||
experimental_hints = collective_util.Hints()
|
||||
if options is None:
|
||||
options = collective_util.Options()
|
||||
return self.reduce_implementation(reduce_op, per_replica_value,
|
||||
destinations, experimental_hints)
|
||||
destinations, options)
|
||||
|
||||
def _gather(self,
|
||||
per_replica_value,
|
||||
destinations,
|
||||
axis,
|
||||
experimental_hints=None):
|
||||
def _gather(self, per_replica_value, destinations, axis, options=None):
|
||||
"""Gather `per_replica_value` to `destinations`.
|
||||
|
||||
Args:
|
||||
@ -319,8 +312,8 @@ class CrossDeviceOps(object):
|
||||
variable.
|
||||
axis: specifies the dimension to gather along within each replica's
|
||||
tensor.
|
||||
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
|
||||
`tf.distribute.experimental.CollectiveHints` for details.
|
||||
options: a `tf.distribute.experimental.CommunicationOptions`. See
|
||||
`tf.distribute.experimental.CommunicationOptions` for details.
|
||||
|
||||
Returns:
|
||||
A `tf.Tensor` or `tf.distribute.DistributedValues`
|
||||
@ -330,8 +323,11 @@ class CrossDeviceOps(object):
|
||||
`tf.distribute.DistributedValues` or if destinations is not a string,
|
||||
`tf.Variable` or `tf.distribute.DistributedValues`.
|
||||
"""
|
||||
if experimental_hints is None:
|
||||
experimental_hints = collective_util.Hints()
|
||||
if isinstance(per_replica_value, ops.IndexedSlices):
|
||||
raise NotImplementedError("gather/all_gather does not support "
|
||||
"IndexedSlices")
|
||||
if options is None:
|
||||
options = collective_util.Options()
|
||||
|
||||
if not isinstance(per_replica_value, value_lib.DistributedValues):
|
||||
per_replica_value = _make_tensor_into_per_replica(per_replica_value)
|
||||
@ -347,10 +343,10 @@ class CrossDeviceOps(object):
|
||||
return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
|
||||
|
||||
return self._gather_implementation(per_replica_value, destinations, axis,
|
||||
experimental_hints)
|
||||
options)
|
||||
|
||||
def _gather_implementation(self, per_replica_value, destinations, axis,
|
||||
experimental_hints):
|
||||
options):
|
||||
"""Implementation of `gather` method of `tf.distribute.CrossDeviceOps`.
|
||||
|
||||
Overriding this method is useful for subclass implementers.
|
||||
@ -366,8 +362,8 @@ class CrossDeviceOps(object):
|
||||
variable.
|
||||
axis: specifies the dimension to gather along within each replica's
|
||||
tensor.
|
||||
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
|
||||
`tf.distribute.experimental.CollectiveHints` for details.
|
||||
options: a `tf.distribute.experimental.CommunicationOptions`. See
|
||||
`tf.distribute.experimental.CommunicationOptions` for details.
|
||||
|
||||
Returns:
|
||||
A `tf.Tensor` or `tf.distribute.DistributedValues`.
|
||||
@ -380,10 +376,7 @@ class CrossDeviceOps(object):
|
||||
raise NotImplementedError(
|
||||
"_gather method must be implemented in descendants.")
|
||||
|
||||
def batch_reduce(self,
|
||||
reduce_op,
|
||||
value_destination_pairs,
|
||||
experimental_hints=None):
|
||||
def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
|
||||
"""Reduce values to destinations in batches.
|
||||
|
||||
See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
|
||||
@ -394,8 +387,8 @@ class CrossDeviceOps(object):
|
||||
combined.
|
||||
value_destination_pairs: a sequence of (value, destinations) pairs. See
|
||||
`tf.distribute.CrossDeviceOps.reduce` for descriptions.
|
||||
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
|
||||
`tf.distribute.experimental.CollectiveHints` for details.
|
||||
options: a `tf.distribute.experimental.CommunicationOptions`. See
|
||||
`tf.distribute.experimental.CommunicationOptions` for details.
|
||||
|
||||
Returns:
|
||||
A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
|
||||
@ -405,6 +398,8 @@ class CrossDeviceOps(object):
|
||||
ValueError: if `value_destination_pairs` is not an iterable of
|
||||
tuples of `tf.distribute.DistributedValues` and destinations.
|
||||
"""
|
||||
if options is None:
|
||||
options = collective_util.Options()
|
||||
# TODO(yuefengz): if destinations are different, split into several
|
||||
# `_batch_reduce` invocations.
|
||||
if not _validate_value_destination_pairs(value_destination_pairs):
|
||||
@ -425,10 +420,10 @@ class CrossDeviceOps(object):
|
||||
for v, _ in value_destination_pairs
|
||||
]
|
||||
|
||||
if experimental_hints is None:
|
||||
experimental_hints = collective_util.Hints()
|
||||
if options is None:
|
||||
options = collective_util.Options()
|
||||
return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
|
||||
experimental_hints)
|
||||
options)
|
||||
|
||||
def broadcast(self, tensor, destinations):
|
||||
"""Broadcast `tensor` to `destinations`.
|
||||
@ -451,7 +446,7 @@ class CrossDeviceOps(object):
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||
experimental_hints):
|
||||
options):
|
||||
"""Implementation of `reduce`.
|
||||
|
||||
Overriding this method is useful for subclass implementers.
|
||||
@ -467,8 +462,8 @@ class CrossDeviceOps(object):
|
||||
`destinations`. Note that if it's a `tf.Variable`, the value is reduced
|
||||
to the devices of that variable, this method doesn't update the
|
||||
variable.
|
||||
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
|
||||
`tf.distribute.experimental.CollectiveHints` for details.
|
||||
options: a `tf.distribute.experimental.CommunicationOptions`. See
|
||||
`tf.distribute.experimental.CommunicationOptions` for details.
|
||||
|
||||
Returns:
|
||||
A `tf.Tensor` or `tf.distribute.DistributedValues`.
|
||||
@ -483,7 +478,7 @@ class CrossDeviceOps(object):
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
options):
|
||||
"""Implementation of `batch_reduce`.
|
||||
|
||||
Overriding this method is useful for subclass implementers.
|
||||
@ -493,8 +488,8 @@ class CrossDeviceOps(object):
|
||||
combined.
|
||||
value_destination_pairs: a sequence of (value, destinations) pairs. See
|
||||
`reduce` for descriptions.
|
||||
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints
|
||||
to perform collective operations.
|
||||
options: a `tf.distribute.experimental.CommunicationOptions`. See
|
||||
`tf.distribute.experimental.CommunicationOptions` for details.
|
||||
|
||||
Returns:
|
||||
A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
|
||||
@ -558,8 +553,8 @@ class ReductionToOneDevice(CrossDeviceOps):
|
||||
super(ReductionToOneDevice, self).__init__()
|
||||
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||
experimental_hints):
|
||||
del experimental_hints # Unused.
|
||||
options):
|
||||
del options # Unused.
|
||||
if check_destinations(destinations):
|
||||
devices = get_devices_from(destinations)
|
||||
else:
|
||||
@ -573,8 +568,8 @@ class ReductionToOneDevice(CrossDeviceOps):
|
||||
return self.broadcast(reduced, destinations)
|
||||
|
||||
def _gather_implementation(self, per_replica_value, destinations, axis,
|
||||
experimental_hints):
|
||||
del experimental_hints # Unused.
|
||||
options):
|
||||
del options # Unused.
|
||||
if check_destinations(destinations):
|
||||
devices = get_devices_from(destinations)
|
||||
else:
|
||||
@ -587,10 +582,10 @@ class ReductionToOneDevice(CrossDeviceOps):
|
||||
return self.broadcast(gathered, destinations)
|
||||
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
options):
|
||||
return [
|
||||
self.reduce_implementation(
|
||||
reduce_op, t, destinations=v, experimental_hints=experimental_hints)
|
||||
reduce_op, t, destinations=v, options=options)
|
||||
for t, v in value_destination_pairs
|
||||
]
|
||||
|
||||
@ -806,8 +801,8 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
|
||||
super(AllReduceCrossDeviceOps, self).__init__()
|
||||
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||
experimental_hints):
|
||||
del experimental_hints # Unused.
|
||||
options):
|
||||
del options # Unused.
|
||||
if _devices_match(per_replica_value, destinations):
|
||||
return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
|
||||
else:
|
||||
@ -815,13 +810,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
|
||||
destinations)
|
||||
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
options):
|
||||
if _all_devices_match(value_destination_pairs):
|
||||
return self._batch_all_reduce(reduce_op,
|
||||
[v[0] for v in value_destination_pairs])
|
||||
else:
|
||||
return [
|
||||
self.reduce_implementation(reduce_op, value, dest, experimental_hints)
|
||||
self.reduce_implementation(reduce_op, value, dest, options)
|
||||
for value, dest in value_destination_pairs
|
||||
]
|
||||
|
||||
@ -882,13 +877,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
|
||||
reduce_op, zip(sparse_values, sparse_values))
|
||||
|
||||
def _gather_implementation(self, per_replica_value, destinations, axis,
|
||||
experimental_hints):
|
||||
options):
|
||||
logging.warning("gather/all_gather with NCCL or HierarchicalCopy is not "
|
||||
"supported. Falling back to gather on one device and "
|
||||
"then broadcast. We're working on a more efficient "
|
||||
"implementation.")
|
||||
return ReductionToOneDevice()._gather(per_replica_value, destinations, axis, # pylint: disable=protected-access
|
||||
experimental_hints)
|
||||
options)
|
||||
|
||||
|
||||
# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
|
||||
@ -979,20 +974,9 @@ class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
|
||||
num_packs=num_packs)
|
||||
|
||||
|
||||
@tf_export("distribute.experimental.CollectiveCommunication")
|
||||
class CollectiveCommunication(enum.Enum):
|
||||
"""Communication choices for CollectiveOps.
|
||||
|
||||
* `AUTO`: Default to runtime's automatic choices.
|
||||
* `RING`: TensorFlow's ring algorithms for all-reduce and
|
||||
all-gather.
|
||||
* `NCCL`: Use ncclAllReduce for all-reduce, and ring algorithms for
|
||||
all-gather.
|
||||
"""
|
||||
AUTO = "AUTO"
|
||||
RING = "RING"
|
||||
NCCL = "NCCL"
|
||||
# TODO(ayushd): add ncclAllGather implementation.
|
||||
# TODO(crccw): remove after migrating all callers.
|
||||
CollectiveCommunication = collective_util.CommunicationImplemenation
|
||||
CommunicationImplemenation = collective_util.CommunicationImplemenation
|
||||
|
||||
|
||||
# TODO(yuefengz): support in-graph collective all-reduce.
|
||||
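The enum deleted here is replaced by module-level aliases, so code that still imports `CollectiveCommunication` from `cross_device_ops` keeps working while the single definition now lives in `collective_util`. A tiny sketch of the effect (hypothetical, with the two modules collapsed into one file):

```python
import enum


# collective_util-style module: the single definition.
class CommunicationImplementation(enum.Enum):
  AUTO = "AUTO"
  RING = "RING"
  NCCL = "NCCL"


# cross_device_ops-style module: the old name kept as an alias.
CollectiveCommunication = CommunicationImplementation

assert CollectiveCommunication.NCCL is CommunicationImplementation.NCCL
```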
@ -1003,11 +987,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
all workers and then put results on the right destinations.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
devices,
|
||||
group_size,
|
||||
collective_keys=None,
|
||||
communication=CollectiveCommunication.AUTO):
|
||||
def __init__(self, devices, group_size, collective_keys=None):
|
||||
"""Initializes the object.
|
||||
|
||||
Args:
|
||||
@ -1015,7 +995,6 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
group_size: the global group size. For between-graph replicated training
|
||||
it's the total number of devices across all workers.
|
||||
collective_keys: an optional CollectiveKey object.
|
||||
communication: indicates which collective communication to use.
|
||||
"""
|
||||
if group_size % len(devices) > 0:
|
||||
raise ValueError("group_size must be divisible by the number of devices.")
|
||||
@ -1023,7 +1002,6 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
self._group_size = group_size
|
||||
self._collective_keys = (collective_keys or
|
||||
cross_device_utils.CollectiveKeys())
|
||||
self._communication = communication
|
||||
# This lock guards all collective launches, i.e. calls to
|
||||
# cross_device_utils.build_collectve_*.
|
||||
#
|
||||
@ -1063,10 +1041,10 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
return self._group_size / len(self._devices)
|
||||
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||
experimental_hints):
|
||||
options):
|
||||
values_util.mark_as_unsaveable()
|
||||
all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
|
||||
experimental_hints)[0]
|
||||
options)[0]
|
||||
devices = get_devices_from(destinations)
|
||||
|
||||
if _devices_match(per_replica_value, destinations):
|
||||
@ -1095,13 +1073,13 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
|
||||
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
options):
|
||||
values_util.mark_as_unsaveable()
|
||||
all_devices_match = _all_devices_match(value_destination_pairs)
|
||||
if all_devices_match:
|
||||
return self._batch_all_reduce(reduce_op,
|
||||
[v[0] for v in value_destination_pairs],
|
||||
experimental_hints)
|
||||
options)
|
||||
else:
|
||||
if not all_devices_match:
|
||||
logging.log_first_n(
|
||||
@ -1109,43 +1087,41 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
"destinations are different.", 10)
|
||||
|
||||
return [
|
||||
self.reduce_implementation(reduce_op, value, dest, experimental_hints)
|
||||
self.reduce_implementation(reduce_op, value, dest, options)
|
||||
for value, dest in value_destination_pairs
|
||||
]
|
||||
|
||||
def _batch_all_reduce(self, reduce_op, per_replica_values,
|
||||
experimental_hints):
|
||||
def _batch_all_reduce(self, reduce_op, per_replica_values, options):
|
||||
"""All reduce algorithm in a batch."""
|
||||
dense_values, dense_indices, sparse_values, sparse_indices = (
|
||||
cross_device_utils.split_by_sparsity(per_replica_values))
|
||||
if dense_values:
|
||||
dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
|
||||
experimental_hints)
|
||||
options)
|
||||
else:
|
||||
dense_results = []
|
||||
if sparse_values:
|
||||
sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
|
||||
sparse_values,
|
||||
experimental_hints)
|
||||
sparse_values, options)
|
||||
else:
|
||||
sparse_results = []
|
||||
return cross_device_utils.stitch_values(
|
||||
((dense_results, dense_indices), (sparse_results, sparse_indices)))
|
||||
|
||||
def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
|
||||
experimental_hints):
|
||||
def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values, options):
|
||||
"""All-reduce across all workers in a batch."""
|
||||
|
||||
batch_size = len(per_replica_values)
|
||||
# Pass self._communication to the runtime as a communication hint.
|
||||
communication = self._communication.value
|
||||
implementation = options.implementation.value
|
||||
# For now, we use NCCL only when batch_size > 1 since we don't have a way to
|
||||
# order NCCL launches. We're hoping that there's only one batched
|
||||
# all-reduce, which is the gradients.
|
||||
# TODO(b/132575814): switch to NCCL for all collectives when communication
|
||||
# is NCCL if and only if we can order collectives deterministically.
|
||||
if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
|
||||
communication = CollectiveCommunication.AUTO.value
|
||||
# is NCCL.
|
||||
if (options.implementation == CommunicationImplemenation.NCCL and
|
||||
batch_size == 1):
|
||||
implementation = CommunicationImplemenation.AUTO.value
|
||||
|
||||
# Reverse the lists so that there's better chance that values follows
|
||||
# the order in which they are calculated (e.g. when they're gradients), so
|
||||
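The hunk above downgrades an explicit NCCL request to AUTO when the batch holds a single all-reduce, because single NCCL launches cannot yet be ordered deterministically. A standalone restatement of that selection rule (hypothetical helper, not the TensorFlow code):

```python
def pick_implementation(requested, batch_size):
  # Mirrors the rule above: only honor NCCL for batched all-reduces.
  if requested == "NCCL" and batch_size == 1:
    return "AUTO"
  return requested


assert pick_implementation("NCCL", 1) == "AUTO"
assert pick_implementation("NCCL", 4) == "NCCL"
assert pick_implementation("RING", 1) == "RING"
```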
@ -1166,15 +1142,15 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
with self._lock:
|
||||
for i in range(len(self._devices)):
|
||||
packs = cross_device_utils.group_by_size(
|
||||
values_by_device[i], experimental_hints.bytes_per_pack)
|
||||
values_by_device[i], options.bytes_per_pack)
|
||||
if not context.executing_eagerly() and i == 0:
|
||||
logging.info(
|
||||
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
|
||||
"group_size = %d, communication_hint = %s, num_packs = %d",
|
||||
batch_size, len(self._launchers), self._group_size, communication,
|
||||
len(packs))
|
||||
"group_size = %d, implementation = %s, num_packs = %d",
|
||||
batch_size, len(self._launchers), self._group_size,
|
||||
implementation, len(packs))
|
||||
outputs_by_device.append(self._launchers[i].batch_all_reduce(
|
||||
packs, communication, experimental_hints.timeout_seconds))
|
||||
packs, implementation, options.timeout_seconds))
|
||||
|
||||
for e in self._executors:
|
||||
e.wait()
|
||||
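In the hunk above, `cross_device_utils.group_by_size` splits each device's tensors into packs bounded by `options.bytes_per_pack` (0 disables packing), which is what lets gradient all-reduce overlap with weight updates. A rough standalone sketch of such size-based grouping (hypothetical helper; the real implementation may differ):

```python
def group_by_size(sizes, bytes_per_pack):
  """Greedily packs items so each pack stays near bytes_per_pack."""
  if bytes_per_pack == 0:
    return [list(range(len(sizes)))]  # One pack: packing disabled.
  packs, current, current_bytes = [], [], 0
  for i, size in enumerate(sizes):
    if current and current_bytes + size > bytes_per_pack:
      packs.append(current)
      current, current_bytes = [], 0
    current.append(i)
    current_bytes += size
  if current:
    packs.append(current)
  return packs


print(group_by_size([30, 30, 30, 90], bytes_per_pack=64))  # [[0, 1], [2], [3]]
```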
@ -1191,8 +1167,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
# Reverse the order of reduced value to recover the order in the input.
|
||||
return list(reversed(mirrored))
|
||||
|
||||
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values,
|
||||
experimental_hints):
|
||||
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values, options):
|
||||
"""All-reduce IndexedSlices across all workers in a batch."""
|
||||
|
||||
logging.log_first_n(
|
||||
@ -1200,8 +1175,13 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
"%d all-reduces, group_size = %d" %
|
||||
(len(per_replica_values), self._group_size), 10)
|
||||
|
||||
# Pass self._communication to the runtime as a communication hint.
|
||||
communication_hint = self._communication.value
|
||||
implementation = options.implementation.value
|
||||
# For now, we use NCCL only when batch_size > 1.
|
||||
# TODO(b/132575814): switch to NCCL for all collectives when implementation
|
||||
# is NCCL.
|
||||
if options.implementation == CommunicationImplemenation.NCCL and len(
|
||||
per_replica_values) == 1:
|
||||
implementation = CommunicationImplemenation.AUTO.value
|
||||
|
||||
gathered_values = []
|
||||
with self._lock:
|
||||
@ -1209,8 +1189,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
outputs = []
|
||||
for i in range(len(self._devices)):
|
||||
outputs.append(self._launchers[i].all_reduce_indexed_slices(
|
||||
per_replica.values[i], communication_hint,
|
||||
experimental_hints.timeout_seconds))
|
||||
per_replica.values[i], implementation, options.timeout_seconds))
|
||||
gathered_values.append(outputs)
|
||||
|
||||
mirrored = []
|
||||
@ -1225,10 +1204,9 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
return mirrored
|
||||
|
||||
def _gather_implementation(self, per_replica_value, destinations, axis,
|
||||
experimental_hints):
|
||||
options):
|
||||
all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
|
||||
values_util.mark_as_unsaveable()
|
||||
all_gathered = self._batch_all_gather([per_replica_value], axis,
|
||||
experimental_hints)[0]
|
||||
devices = get_devices_from(destinations)
|
||||
|
||||
if _devices_match(per_replica_value, destinations):
|
||||
@ -1254,16 +1232,23 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
index.append(array_ops.identity(all_gathered._primary)) # pylint: disable=protected-access
|
||||
return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
|
||||
|
||||
def _batch_all_gather(self, per_replica_values, axis, experimental_hints):
|
||||
def _batch_all_gather(self, per_replica_values, axis, options):
|
||||
"""all gather multiple per-replica-values."""
|
||||
batch_size = len(per_replica_values)
|
||||
# Pass self._communication to the runtime as a communication hint.
|
||||
communication = self._communication.value
|
||||
# Pass options.implementation to the runtime as a communication
|
||||
# implementation hint.
|
||||
implementation = options.implementation.value
|
||||
# For now, we use NCCL only when batch_size > 1.
|
||||
# TODO(b/132575814): switch to NCCL for all collectives when implementation
|
||||
# is NCCL.
|
||||
if (options.implementation == CommunicationImplemenation.NCCL and
|
||||
batch_size == 1):
|
||||
implementation = CommunicationImplemenation.AUTO.value
|
||||
|
||||
logging.log_first_n(
|
||||
logging.INFO, "Collective batch_all_gather: %d all-gathers, "
|
||||
"num_devices = %d, group_size = %d, communication_hint = %s, " %
|
||||
(batch_size, len(self._devices), self._group_size, communication), 10)
|
||||
"num_devices = %d, group_size = %d, implementation = %s, " %
|
||||
(batch_size, len(self._devices), self._group_size, implementation), 10)
|
||||
|
||||
def compute_gathered_values():
|
||||
gathered_values = []
|
||||
@ -1272,8 +1257,8 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
outputs = []
|
||||
for i in range(len(self._devices)):
|
||||
outputs.append(self._launchers[i].all_gather(
|
||||
per_replica.values[i], axis, communication,
|
||||
experimental_hints.timeout_seconds))
|
||||
per_replica.values[i], axis, implementation,
|
||||
options.timeout_seconds))
|
||||
gathered_values.append(outputs)
|
||||
return gathered_values
|
||||
|
||||
@ -1292,8 +1277,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
# distribute_coordinator deep-copies the strategy object, so
|
||||
# CollectiveAllReduce needs to support deep copy as well.
|
||||
collective_keys = copy.deepcopy(self._collective_keys, memo)
|
||||
return CollectiveAllReduce(self._devices, self._group_size, collective_keys,
|
||||
self._communication)
|
||||
return CollectiveAllReduce(self._devices, self._group_size, collective_keys)
|
||||
|
||||
|
||||
def select_cross_device_ops(devices, session_config=None):
|
||||
|
@ -48,7 +48,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import collective_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication
|
||||
CommunicationImplemenation = collective_util.CommunicationImplemenation
|
||||
ReduceOp = reduce_util.ReduceOp
|
||||
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
|
||||
IndexedSlices = indexed_slices.IndexedSlices
|
||||
@ -140,14 +140,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
global_mpr_1p.runner.run(enable_collective_ops)
|
||||
global_mpr_2p.runner.run(enable_collective_ops)
|
||||
|
||||
def make_collective(self, num_processes, gpu_per_process, communication):
|
||||
def make_collective(self, num_processes, gpu_per_process):
|
||||
"""Returns collectives and other info to be used in tests.
|
||||
|
||||
Args:
|
||||
num_processes: an integer indicating the number of processes that
|
||||
participate in the collective.
|
||||
gpu_per_process: number of GPUs (0 if no GPUs) used by each process.
|
||||
communication: one of `CollectiveCommunication`.
|
||||
|
||||
Returns:
|
||||
A tuple of (collective, devices, group_size) where collective is an instance
|
||||
@ -167,7 +166,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
]
|
||||
group_size = num_processes * len(devices)
|
||||
collective = cross_device_ops_lib.CollectiveAllReduce(
|
||||
devices=devices, group_size=group_size, communication=communication)
|
||||
devices=devices, group_size=group_size)
|
||||
return collective, devices, cluster_resolver.task_id
|
||||
|
||||
def as_list(self, value):
|
||||
@ -202,12 +201,12 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"num_processes",
|
||||
"gpus_per_process",
|
||||
"reduce_op",
|
||||
"communication",
|
||||
"communication_options",
|
||||
"use_scoped_allocator",
|
||||
])
|
||||
RunOptions.__new__.__defaults__ = (["eager",
|
||||
"func_graph"], 2, 0, ReduceOp.SUM,
|
||||
CollectiveCommunication.AUTO, True)
|
||||
collective_util.Options(), True)
|
||||
|
||||
def reduce_and_verify(self, inputs, expect, options):
|
||||
"""Reduce the given `inputs` and verify the output matches `expect`.
|
||||
@ -222,14 +221,14 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def replica_fn():
|
||||
collective, devices, pid = self.make_collective(options.num_processes,
|
||||
options.gpus_per_process,
|
||||
options.communication)
|
||||
options.gpus_per_process)
|
||||
|
||||
def reduce_fn():
|
||||
value_fn = lambda device_idx: inputs[pid * len(devices) + device_idx]
|
||||
per_replica_value = make_per_replica_value(value_fn, devices)
|
||||
reduced_values = collective.reduce(options.reduce_op, per_replica_value,
|
||||
per_replica_value)
|
||||
per_replica_value,
|
||||
options.communication_options)
|
||||
reduced_values = self.as_list(reduced_values)
|
||||
self.assertAllEqual(devices, [v.device for v in reduced_values])
|
||||
return [ops.convert_to_tensor(v) for v in reduced_values]
|
||||
@ -261,8 +260,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_scoped_allocator = (
|
||||
options.use_scoped_allocator)
|
||||
collective, devices, pid = self.make_collective(options.num_processes,
|
||||
options.gpus_per_process,
|
||||
options.communication)
|
||||
options.gpus_per_process)
|
||||
|
||||
def batch_reduce_fn():
|
||||
batch_size = len(inputs[0])
|
||||
@ -275,7 +273,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
per_replica_value = make_per_replica_value(value_fn, devices)
|
||||
value_dst_pairs.append((per_replica_value, per_replica_value))
|
||||
reduced_values = collective.batch_reduce(options.reduce_op,
|
||||
value_dst_pairs)
|
||||
value_dst_pairs,
|
||||
options.communication_options)
|
||||
reduced_values = [self.as_list(v) for v in reduced_values]
|
||||
for v in reduced_values:
|
||||
self.assertAllEqual(devices, [t.device for t in v])
|
||||
@ -298,20 +297,21 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
combinations.combine(
|
||||
num_processes=[1, 2],
|
||||
required_gpus=[0, 1, 2],
|
||||
communication=[
|
||||
implementation=[
|
||||
# NCCL is only used for batch reduce, so we are not including
|
||||
# NCCL combination here.
|
||||
CollectiveCommunication.AUTO,
|
||||
CollectiveCommunication.RING
|
||||
CommunicationImplemenation.AUTO,
|
||||
CommunicationImplemenation.RING
|
||||
],
|
||||
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN]))
|
||||
def testAllReduceDense(self, num_processes, required_gpus, communication,
|
||||
def testAllReduceDense(self, num_processes, required_gpus, implementation,
|
||||
reduce_op):
|
||||
options = self.RunOptions(
|
||||
num_processes=num_processes,
|
||||
gpus_per_process=required_gpus,
|
||||
reduce_op=reduce_op,
|
||||
communication=communication)
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation))
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = [1.0, 2.0, 3.0, 4.0]
|
||||
@ -330,22 +330,23 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
combinations.combine(
|
||||
num_processes=[1, 2],
|
||||
required_gpus=[0, 1, 2],
|
||||
communication=[
|
||||
implementation=[
|
||||
# NCCL is only used for batch reduce, so we are not including
|
||||
# NCCL combination here.
|
||||
CollectiveCommunication.AUTO,
|
||||
CollectiveCommunication.RING
|
||||
CommunicationImplemenation.AUTO,
|
||||
CommunicationImplemenation.RING
|
||||
],
|
||||
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
|
||||
reduce_op=ReduceOp.SUM))
|
||||
def testAllReduceSparse(self, num_processes, required_gpus, communication,
|
||||
def testAllReduceSparse(self, num_processes, required_gpus, implementation,
|
||||
reduce_op):
|
||||
options = self.RunOptions(
|
||||
mode=["func_graph"], # Sparse reduce is not supported in eager.
|
||||
num_processes=num_processes,
|
||||
gpus_per_process=required_gpus,
|
||||
reduce_op=reduce_op,
|
||||
communication=communication)
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation))
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = [
|
||||
@ -399,17 +400,17 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
combinations.combine(
|
||||
num_processes=[1, 2],
|
||||
required_gpus=[0, 1, 2],
|
||||
communication=[
|
||||
CollectiveCommunication.AUTO, CollectiveCommunication.RING,
|
||||
CollectiveCommunication.NCCL
|
||||
implementation=[
|
||||
CommunicationImplemenation.AUTO, CommunicationImplemenation.RING,
|
||||
CommunicationImplemenation.NCCL
|
||||
],
|
||||
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
|
||||
use_scoped_allocator=[True, False]))
|
||||
def testBatchAllReduceDense(self, num_processes, required_gpus, communication,
|
||||
reduce_op, use_scoped_allocator):
|
||||
if required_gpus == 0 and communication == CollectiveCommunication.NCCL:
|
||||
def testBatchAllReduceDense(self, num_processes, required_gpus,
|
||||
implementation, reduce_op, use_scoped_allocator):
|
||||
if required_gpus == 0 and implementation == CommunicationImplemenation.NCCL:
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
if num_processes == 2 and communication == CollectiveCommunication.NCCL:
|
||||
if num_processes == 2 and implementation == CommunicationImplemenation.NCCL:
|
||||
self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
|
||||
"physical GPUs for every process.")
|
||||
|
||||
@ -417,7 +418,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
num_processes=num_processes,
|
||||
gpus_per_process=required_gpus,
|
||||
reduce_op=reduce_op,
|
||||
communication=communication,
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation),
|
||||
use_scoped_allocator=use_scoped_allocator)
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
@ -437,19 +439,19 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
combinations.combine(
|
||||
num_processes=[1, 2],
|
||||
required_gpus=[0, 1, 2],
communication=[
CollectiveCommunication.AUTO,
CollectiveCommunication.RING,
CollectiveCommunication.NCCL,
implementation=[
CommunicationImplemenation.AUTO,
CommunicationImplemenation.RING,
CommunicationImplemenation.NCCL,
],
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
reduce_op=ReduceOp.SUM,
use_scoped_allocator=[True, False]))
def testBatchAllReduceSparse(self, num_processes, required_gpus,
communication, reduce_op, use_scoped_allocator):
if required_gpus == 0 and communication == CollectiveCommunication.NCCL:
implementation, reduce_op, use_scoped_allocator):
if required_gpus == 0 and implementation == CommunicationImplemenation.NCCL:
self.skipTest("Skip CPU + NCCL combination")
if num_processes == 2 and communication == CollectiveCommunication.NCCL:
if num_processes == 2 and implementation == CommunicationImplemenation.NCCL:
self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
"physical GPUs for every process.")

@ -458,7 +460,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
num_processes=num_processes,
gpus_per_process=required_gpus,
reduce_op=reduce_op,
communication=communication,
communication_options=collective_util.Options(
implementation=implementation),
use_scoped_allocator=use_scoped_allocator)
group_size = options.num_processes * (options.gpus_per_process or 1)

@ -522,24 +525,23 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
required_gpus=[0, 1, 2],
axis=[0, 1, 2],
func_mode=["eager", "func_graph"],
communication=[
CollectiveCommunication.NCCL,
CollectiveCommunication.AUTO,
CollectiveCommunication.RING
implementation=[
CommunicationImplemenation.NCCL, CommunicationImplemenation.AUTO,
CommunicationImplemenation.RING
]))
def testAllGatherSameShape(self, num_processes, required_gpus, communication,
def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
func_mode, axis):

def replica_fn():
collective, devices, _ = self.make_collective(num_processes,
required_gpus,
communication)
required_gpus)
options = collective_util.Options(implementation=implementation)
value = constant_op.constant([[[1, 2], [1, 2]]], dtype=dtypes.float32)

def gather_fn():
per_replica_value = make_per_replica_value(value, devices)
gathered_values = collective._gather(
per_replica_value, per_replica_value, axis=axis)
per_replica_value, per_replica_value, axis=axis, options=options)
gathered_values = self.as_list(gathered_values)
# Skip checking devices in eager. In eager the device attribute doesn't
# reflect the actual device of the tensor.

@ -565,17 +567,17 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=1,
required_gpus=2,
communication=[
CollectiveCommunication.NCCL, CollectiveCommunication.RING
implementation=[
CommunicationImplemenation.NCCL, CommunicationImplemenation.RING
]))
def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
required_gpus,
communication):
implementation):

def replica_fn():
collective, devices, _ = self.make_collective(num_processes,
required_gpus,
communication)
required_gpus)
options = collective_util.Options(implementation=implementation)

# We would like to simulate the following sequence:
# thread-0 device0 device1

@ -604,14 +606,15 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):

def thread_fn():
reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
[(v0, v0), (v0, v0)])
[(v0, v0), (v0, v0)], options)
self.assertAllEqual(reduced[0].values, [2.0, 2.0])
self.assertAllEqual(reduced[1].values, [2.0, 2.0])

t = threading.Thread(target=thread_fn)
t.start()
reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v1, v1),
(v1, v1)])
(v1, v1)],
options)
self.assertAllEqual(reduced[0].values, [4.0, 4.0])
self.assertAllEqual(reduced[1].values, [4.0, 4.0])
t.join()

@ -622,16 +625,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=1,
required_gpus=2,
communication=[
CollectiveCommunication.NCCL, CollectiveCommunication.RING
implementation=[
CommunicationImplemenation.NCCL, CommunicationImplemenation.RING
]))
def testInputsAreFunctionArgs(self, num_processes, required_gpus,
communication):
implementation):

def replica_fn():
collective, devices, _ = self.make_collective(num_processes,
required_gpus,
communication)
required_gpus)
options = collective_util.Options(implementation=implementation)

@def_function.function
def reduce_fn(v):

@ -641,7 +644,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
# We only use NCCL for batch reduce with two or more values, so we use
# two values here.
reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v),
(v, v)])
(v, v)],
options)
self.assertEqual(reduced[0].values[0].device, devices[0])
self.assertEqual(reduced[0].values[1].device, devices[1])
self.assertEqual(reduced[1].values[0].device, devices[0])

@ -660,21 +664,23 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=2,
required_gpus=[0, 1],
communication=[CollectiveCommunication.RING]))
def testTimeoutReduceDense(self, num_processes, communication, required_gpus):
implementation=[CommunicationImplemenation.RING]))
def testTimeoutReduceDense(self, num_processes, implementation,
required_gpus):

def replica_fn():
collective, devices, task_id = self.make_collective(
num_processes, required_gpus, communication)
num_processes, required_gpus)
if task_id != 0:
return

v = make_per_replica_value(1.0, devices)
hints = collective_util.Hints(timeout_seconds=1)
options = collective_util.Options(
timeout_seconds=1, implementation=implementation)

@def_function.function
def reduce_dense():
collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)
collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

# The collective should time out because we only launch it on worker-0,
# while there're three workers in total.

@ -687,23 +693,24 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=2,
required_gpus=[0, 1],
communication=[CollectiveCommunication.RING]))
def testTimeoutBatchReduceDense(self, num_processes, communication,
implementation=[CommunicationImplemenation.RING]))
def testTimeoutBatchReduceDense(self, num_processes, implementation,
required_gpus):

def replica_fn():
collective, devices, task_id = self.make_collective(
num_processes, required_gpus, communication)
num_processes, required_gpus)
if task_id != 0:
return

v = make_per_replica_value(1.0, devices)
hints = collective_util.Hints(timeout_seconds=1)
options = collective_util.Options(
timeout_seconds=1, implementation=implementation)

@def_function.function
def batch_reduce_dense():
collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)],
hints)
options)

# The collective should time out because we only launch it on worker-0,
# while there're two workers in total.

@ -716,24 +723,25 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=2,
required_gpus=[0, 1],
communication=[CollectiveCommunication.RING]))
def testTimeoutReduceSparse(self, num_processes, communication,
implementation=[CommunicationImplemenation.RING]))
def testTimeoutReduceSparse(self, num_processes, implementation,
required_gpus):

def replica_fn():
collective, devices, task_id = self.make_collective(
num_processes, required_gpus, communication)
num_processes, required_gpus)
if task_id != 0:
return

v = make_per_replica_value(
IndexedSlicesValue(
values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
hints = collective_util.Hints(timeout_seconds=1)
options = collective_util.Options(
timeout_seconds=1, implementation=implementation)

@def_function.function
def reduce_sparse():
collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)
collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

# The collective should time out because we only launch it on worker-0,
# while there're two workers in total.

@ -746,25 +754,26 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=2,
required_gpus=[0, 1],
communication=[CollectiveCommunication.RING]))
implementation=[CommunicationImplemenation.RING]))
def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
communication):
implementation):

def replica_fn():
collective, devices, task_id = self.make_collective(
num_processes, required_gpus, communication)
num_processes, required_gpus)
if task_id != 0:
return

v = make_per_replica_value(
IndexedSlicesValue(
values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
hints = collective_util.Hints(timeout_seconds=1)
options = collective_util.Options(
timeout_seconds=1, implementation=implementation)

@def_function.function
def batch_reduce_sparse():
collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)],
hints)
options)

# The collective should time out because we only launch it on worker-0,
# while there're two workers in total.
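Editor's note: the pattern these tests migrate to is mechanical. Call sites that previously built `collective_util.Hints` now build `collective_util.Options` and pass it through the new `options` argument, which also carries the preferred implementation. A minimal sketch, assuming the internal `collective_util` module and the `CommunicationImplemenation` enum as spelled in this change:

```python
from tensorflow.python.distribute import collective_util

# Old style (removed by this change):
#   hints = collective_util.Hints(timeout_seconds=1)
#   collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)

# New style: one Options object carries both the timeout and the preferred
# communication implementation (AUTO, RING or NCCL).
options = collective_util.Options(
    timeout_seconds=1,
    implementation=collective_util.CommunicationImplemenation.RING)
# collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)
```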
@ -2182,7 +2182,7 @@ class StrategyExtendedV2(object):
dst = device_util.current() or self._default_device or "/device:CPU:0"
return self._local_results(self.reduce_to(reduce_op, value, dst))[0]

def reduce_to(self, reduce_op, value, destinations, experimental_hints=None):
def reduce_to(self, reduce_op, value, destinations, options=None):
"""Combine (via e.g. sum or mean) values across replicas.

`reduce_to` aggregates `tf.distribute.DistributedValues` and distributed

@ -2247,12 +2247,17 @@ class StrategyExtendedV2(object):
`destinations`. Note that if it's a `tf.Variable`, the value is reduced
to the devices of that variable, and this method doesn't update the
variable.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.

Returns:
A tensor or value reduced to `destinations`.
"""
if options is None:
options = collective_util.Options()
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(destinations, (list, tuple))
assert not isinstance(reduce_op, variable_scope.VariableAggregation)

@ -2260,17 +2265,12 @@ class StrategyExtendedV2(object):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
assert (reduce_op == reduce_util.ReduceOp.SUM or
reduce_op == reduce_util.ReduceOp.MEAN)
if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self._reduce_to(reduce_op, value, destinations, experimental_hints)
return self._reduce_to(reduce_op, value, destinations, options)

def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
raise NotImplementedError("must be implemented in descendants")

def batch_reduce_to(self,
reduce_op,
value_destination_pairs,
experimental_hints=None):
def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None):
"""Combine multiple `reduce_to` calls into one for faster execution.

Similar to `reduce_to`, but accepts a list of (value, destinations) pairs.

@ -2325,30 +2325,30 @@ class StrategyExtendedV2(object):
"SUM", "MEAN".
value_destination_pairs: a sequence of (value, destinations) pairs. See
`tf.distribute.Strategy.reduce_to` for descriptions.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.

Returns:
A list of reduced values, one per pair in `value_destination_pairs`.
"""
if options is None:
options = collective_util.Options()
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(reduce_op, variable_scope.VariableAggregation)
if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self._batch_reduce_to(reduce_op, value_destination_pairs,
experimental_hints)
return self._batch_reduce_to(reduce_op, value_destination_pairs, options)

def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
return [
self.reduce_to(
reduce_op, t, destinations=v, experimental_hints=experimental_hints)
self.reduce_to(reduce_op, t, destinations=v, options=options)
for t, v in value_destination_pairs
]

def _gather_to(self, value, destinations, axis, experimental_hints=None):
def _gather_to(self, value, destinations, axis, options=None):
"""Gather `value` across replicas along axis-th dimension to `destinations`.

`gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like

@ -2365,31 +2365,30 @@ class StrategyExtendedV2(object):
variable.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.

Returns:
A tensor or value gathered to `destinations`.
"""
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(destinations, (list, tuple))
if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self._gather_to_implementation(value, destinations, axis, experimental_hints)
if options is None:
options = collective_util.Options()
return self._gather_to_implementation(value, destinations, axis, options)

def _gather_to_implementation(self, value, destinations, axis, experimental_hints):
def _gather_to_implementation(self, value, destinations, axis, options):
raise NotImplementedError("_gather_to must be implemented in descendants")

def _batch_gather_to(self,
value_destination_pairs,
axis,
experimental_hints=None):
def _batch_gather_to(self, value_destination_pairs, axis, options=None):
_require_cross_replica_or_default_context_extended(self)
if experimental_hints is None:
experimental_hints = collective_util.Hints()
if options is None:
options = collective_util.Options()
return [
self._gather_to(
t, destinations=v, axis=axis, experimental_hints=experimental_hints)
self._gather_to(t, destinations=v, axis=axis, options=options)
for t, v in value_destination_pairs
]

@ -2410,7 +2409,8 @@ class StrategyExtendedV2(object):
Example usage:

```python
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2 devices
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2
devices
with strategy.scope():
v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
def update_fn(v):

@ -2975,7 +2975,7 @@ class ReplicaContext(object):
require_replica_context(self)
return (device_util.current(),)

def all_reduce(self, reduce_op, value, experimental_hints=None):
def all_reduce(self, reduce_op, value, options=None):
"""All-reduces `value` across all replicas.

>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

@ -2988,7 +2988,7 @@ class ReplicaContext(object):
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

It supports batched operations. You can pass a list of values and it
attempts to batch them when possible. You can also specify `experimental_hints`
attempts to batch them when possible. You can also specify `options`
to indicate the desired batching behavior, e.g. batch the values into
multiple packs so that they can better overlap with computations.

@ -3028,8 +3028,11 @@ class ReplicaContext(object):
value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts.
The structure and the shapes of the `tf.Tensor` need to be same on all
replicas.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.

Returns:
A nested structure of `tf.Tensor` with the reduced values. The structure

@ -3037,13 +3040,13 @@ class ReplicaContext(object):
"""
if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
if experimental_hints is None:
experimental_hints = collective_util.Hints()
if options is None:
options = collective_util.Options()

def batch_all_reduce(strategy, *value_flat):
return strategy.extended.batch_reduce_to(
reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
experimental_hints)
options)

if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
# TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.

@ -3070,7 +3073,7 @@ class ReplicaContext(object):
# implemented in terms of `merge_call()` and `batch_reduce_to()`.

# TODO(wxinyi): generate docs after it is implemented for all strategies.
def _all_gather(self, value, axis, experimental_hints=None):
def _all_gather(self, value, axis, options=None):
"""All-gathers `value` across all replicas along `axis`.

Note: An `all_gather` method can only be called in replica context. To find

@ -3147,8 +3150,11 @@ class ReplicaContext(object):
constructs can only be dense tensors with non-zero rank, NOT
`tf.IndexedSlices`.
axis: 0-D int32 Tensor. Dimension along which to gather.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.

Returns:
A nested structure of `tf.Tensor` with the gathered values. The structure

@ -3159,13 +3165,13 @@ class ReplicaContext(object):
raise NotImplementedError("gather/all_gather does not support "
"IndexedSlices")

if experimental_hints is None:
experimental_hints = collective_util.Hints()
if options is None:
options = collective_util.Options()

def batch_all_gather(strategy, *value_flat):
return strategy.extended._batch_gather_to( # pylint: disable=protected-access
[(v, _batch_reduce_destination(v)) for v in value_flat], axis,
experimental_hints)
options)

@custom_gradient.custom_gradient
def grad_wrapper(*xs):

@ -3319,13 +3325,13 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0):
return fn(*args, **kwargs)

def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
# TODO(josh11b): Use destinations?
del reduce_op, destinations, experimental_hints
del reduce_op, destinations, options
return value

def _gather_to_implementation(self, value, destinations, axis, experimental_hints):
del destinations, axis, experimental_hints
def _gather_to_implementation(self, value, destinations, axis, options):
del destinations, axis, options
return value

def _update(self, var, fn, args, kwargs, group):
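Editor's note: from the caller's side, the renamed argument is used the same way across `reduce_to`, `batch_reduce_to` and `ReplicaContext.all_reduce`. A minimal sketch against the public API documented above (assumes two local GPUs; the tuning values are illustrative only):

```python
import tensorflow as tf

# Replaces the old `experimental_hints` keyword; see the docstrings above.
options = tf.distribute.experimental.CommunicationOptions(
    bytes_per_pack=50 * 1024 * 1024,  # batch values into roughly 50MB packs
    timeout_seconds=120.0)            # report an error if a collective hangs

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

@tf.function
def step():
  ctx = tf.distribute.get_replica_context()
  # `options=` replaces the old `experimental_hints=` keyword argument.
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.constant(1.0),
                        options=options)

strategy.run(step)
```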
@ -94,8 +94,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
def _local_results(self, value):
return (value,)

def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
del reduce_op, destinations, experimental_hints
def _reduce_to(self, reduce_op, value, destinations, options):
del reduce_op, destinations, options
return value

def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
@ -20,6 +20,7 @@ from __future__ import print_function

import copy

from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib

@ -313,6 +314,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
assert devices, ("Got an empty `devices` list and unable to recognize "
"any local devices.")
self._cross_device_ops = cross_device_ops
self._communication_options = collective_util.Options()
self._initialize_strategy(devices)

# TODO(b/128995245): Enable last partial batch support in graph mode.

@ -632,8 +634,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
del value # Unused.
return self._cross_device_ops or self._inferred_cross_device_ops

def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
def _gather_to_implementation(self, value, destinations, axis, options):
if not isinstance(value, values.DistributedValues):
# ReductionToOneDevice._gather accepts DistributedValues only.
return value

@ -641,9 +642,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
value,
destinations=destinations,
axis=axis,
experimental_hints=experimental_hints)
options=self._communication_options.merge(options))

def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
if (distribute_utils.is_mirrored(value) and
reduce_op == reduce_util.ReduceOp.MEAN):
return value

@ -659,10 +660,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
reduce_op,
value,
destinations=destinations,
experimental_hints=experimental_hints)
options=self._communication_options.merge(options))

def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
cross_device_ops = None
for value, _ in value_destination_pairs:
if cross_device_ops is None:

@ -670,8 +670,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
elif cross_device_ops is not self._get_cross_device_ops(value):
raise ValueError("inputs to batch_reduce_to must be either all on the "
"the host or all on the compute devices")
return cross_device_ops.batch_reduce(reduce_op, value_destination_pairs,
experimental_hints)
return cross_device_ops.batch_reduce(
reduce_op,
value_destination_pairs,
options=self._communication_options.merge(options))

def _update(self, var, fn, args, kwargs, group):
# TODO(josh11b): In eager mode, use one thread per device.
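Editor's note: `MirroredExtended` now keeps a strategy-level `collective_util.Options()` default and merges the per-call `options` into it before handing the result to the cross-device ops, so per-call settings take precedence over the constructor defaults (matching the docstring wording in `distribute_lib.py` above). A rough sketch of that merge, assuming the internal `collective_util` API as used in the hunk above:

```python
from tensorflow.python.distribute import collective_util

# Strategy-level defaults, as created in MirroredExtended.__init__ above.
default_options = collective_util.Options()

# Per-call options supplied through reduce_to / batch_reduce_to / gather.
per_call_options = collective_util.Options(timeout_seconds=1)

# What _reduce_to / _batch_reduce_to / _gather_to_implementation forward
# to the cross-device ops.
merged = default_options.merge(per_call_options)
```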
@ -379,13 +379,12 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
with ops.device(self._device), _OneDeviceReplicaContext(strategy):
return fn(*args, **kwargs)

def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
del reduce_op, destinations, experimental_hints
def _reduce_to(self, reduce_op, value, destinations, options):
del reduce_op, destinations, options
return value

def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
del destinations, axis, experimental_hints
def _gather_to_implementation(self, value, destinations, axis, options):
del destinations, axis, options
return value

def _update(self, var, fn, args, kwargs, group):
@ -504,7 +504,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
(d, self._worker_device))

def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
options):
self._verify_destinations_not_different_worker(destinations)
if not isinstance(value, values.DistributedValues):
return value

@ -512,27 +512,22 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
value,
destinations=destinations,
axis=axis,
experimental_hints=experimental_hints)
options=options)

def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
self._verify_destinations_not_different_worker(destinations)
if not isinstance(value, values.DistributedValues):
# pylint: disable=protected-access
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, self._num_replicas_in_sync)
return self._cross_device_ops.reduce(
reduce_op,
value,
destinations=destinations,
experimental_hints=experimental_hints)
reduce_op, value, destinations=destinations, options=options)

def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
for _, destinations in value_destination_pairs:
self._verify_destinations_not_different_worker(destinations)
return self._cross_device_ops.batch_reduce(reduce_op,
value_destination_pairs,
experimental_hints)
value_destination_pairs, options)

def _select_single_value(self, structured):
"""Select any single value in `structured`."""
@ -1022,8 +1022,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs)

def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
def _gather_to_implementation(self, value, destinations, axis, options):
if not isinstance(value, values.DistributedValues):
return value

@ -1070,7 +1069,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):

return output

def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
if (isinstance(value, values.DistributedValues) or
tensor_util.is_tensor(value)
) and tpu_values.enclosing_tpu_context() is not None:

@ -1412,8 +1411,8 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
return self.strategy.extended.experimental_logical_device(logical_device_id)

# TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it.
def _all_gather(self, value, axis, experimental_hints=None):
del experimental_hints
def _all_gather(self, value, axis, options=None):
del options
for v in nest.flatten(value):
if isinstance(v, ops.IndexedSlices):
raise NotImplementedError("gather/all_gather does not support "
@ -432,7 +432,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):

NUM_WORKERS = 3

CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication
CollectiveCommunication = collective_util.CollectiveCommunication


class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,

@ -477,8 +477,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=devices,
group_size=len(devices),
collective_keys=collective_keys,
communication=communication)
collective_keys=collective_keys)
return collective_all_reduce_ops, devices, ""
else:
# NCCL requires physical GPUs for every replica, which we can't do with

@ -509,8 +508,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=devices,
group_size=len(devices) * NUM_WORKERS,
collective_keys=collective_keys,
communication=communication)
collective_keys=collective_keys)
return (collective_all_reduce_ops, devices,
"grpc://" + self._cluster_spec[task_type][task_id])
@ -8,11 +8,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"

@ -24,10 +24,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

@ -10,11 +10,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"

@ -26,10 +26,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

@ -10,11 +10,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"

@ -26,10 +26,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

@ -9,11 +9,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"

@ -25,10 +25,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

@ -24,7 +24,7 @@ tf_class {
}
member_method {
name: "all_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "merge_call"

@ -37,7 +37,7 @@ tf_class {
}
member_method {
name: "batch_reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "broadcast_to"

@ -69,7 +69,7 @@ tf_class {
}
member_method {
name: "reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "update"

@ -1,16 +1,16 @@
path: "tensorflow.distribute.experimental.CollectiveCommunication"
tf_class {
is_instance: "<enum \'CollectiveCommunication\'>"
is_instance: "<enum \'CommunicationImplemenation\'>"
member {
name: "AUTO"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "NCCL"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "RING"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
}

@ -0,0 +1,16 @@
path: "tensorflow.distribute.experimental.CommunicationImplemenation"
tf_class {
is_instance: "<enum \'CommunicationImplemenation\'>"
member {
name: "AUTO"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "NCCL"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "RING"
mtype: "<enum \'CommunicationImplemenation\'>"
}
}

@ -0,0 +1,9 @@
path: "tensorflow.distribute.experimental.CommunicationOptions"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.collective_util._OptionsExported\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'bytes_per_pack\', \'timeout_seconds\', \'implementation\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'CommunicationImplemenation.AUTO\'], "
}
}

@ -18,7 +18,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CollectiveCommunication.AUTO\', \'None\'], "
argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CommunicationImplemenation.AUTO\', \'None\'], "
}
member_method {
name: "colocate_vars_with"

@ -12,6 +12,14 @@ tf_module {
name: "CollectiveHints"
mtype: "<type \'type\'>"
}
member {
name: "CommunicationImplemenation"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "CommunicationOptions"
mtype: "<type \'type\'>"
}
member {
name: "MultiWorkerMirroredStrategy"
mtype: "<type \'type\'>"
@ -8,11 +8,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"

@ -24,10 +24,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

@ -10,11 +10,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"

@ -26,10 +26,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

@ -10,11 +10,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"

@ -26,10 +26,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

@ -9,11 +9,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"

@ -25,10 +25,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

@ -24,7 +24,7 @@ tf_class {
}
member_method {
name: "all_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "merge_call"

@ -20,7 +20,7 @@ tf_class {
}
member_method {
name: "batch_reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "colocate_vars_with"

@ -28,7 +28,7 @@ tf_class {
}
member_method {
name: "reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "update"

@ -1,16 +1,16 @@
path: "tensorflow.distribute.experimental.CollectiveCommunication"
tf_class {
is_instance: "<enum \'CollectiveCommunication\'>"
is_instance: "<enum \'CommunicationImplemenation\'>"
member {
name: "AUTO"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "NCCL"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "RING"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
}

@ -0,0 +1,16 @@
path: "tensorflow.distribute.experimental.CommunicationImplemenation"
tf_class {
is_instance: "<enum \'CommunicationImplemenation\'>"
member {
name: "AUTO"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "NCCL"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "RING"
mtype: "<enum \'CommunicationImplemenation\'>"
}
}

@ -0,0 +1,9 @@
path: "tensorflow.distribute.experimental.CommunicationOptions"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.collective_util._OptionsExported\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'bytes_per_pack\', \'timeout_seconds\', \'implementation\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'CommunicationImplemenation.AUTO\'], "
}
}

@ -18,7 +18,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CollectiveCommunication.AUTO\', \'None\'], "
argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CommunicationImplemenation.AUTO\', \'None\'], "
}
member_method {
name: "colocate_vars_with"

@ -12,6 +12,14 @@ tf_module {
name: "CollectiveHints"
mtype: "<type \'type\'>"
}
member {
name: "CommunicationImplemenation"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "CommunicationOptions"
mtype: "<type \'type\'>"
}
member {
name: "MultiWorkerMirroredStrategy"
mtype: "<type \'type\'>"
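Editor's note: the golden files above record the new public surface: `CollectiveHints` is kept but the same options are now exported as `CommunicationOptions`, and the enum is exported under the spelling `CommunicationImplemenation` used throughout this change. A sketch of constructing the options object with every recorded argument (values are placeholders, defaults taken from the golden argspec above):

```python
import tensorflow as tf

experimental = tf.distribute.experimental
# Arguments and defaults exactly as recorded in the golden files above.
options = experimental.CommunicationOptions(
    bytes_per_pack=0,
    timeout_seconds=None,
    implementation=experimental.CommunicationImplemenation.AUTO)
```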