Graduate experimental_hints to options in all_reduce/reduce/batch_reduce

The CollectiveHints class is also renamed to CommunicationOptions, and the communication implementation enum is added to it as an option.

CommunicationOptions stays experimental since the detailed options may change, but it is clear that these cross-device communication methods need an options argument.
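
As a hedged illustration of the graduated argument (not part of this diff; the gradient names and option values below are invented):

```python
import tensorflow as tf

# Hypothetical values; only the class and argument names come from this change.
options = tf.distribute.experimental.CommunicationOptions(
    bytes_per_pack=32 * 1024 * 1024,  # break gradient all-reduce into ~32MB packs
    timeout_seconds=120)              # raise DeadlineExceededError on hangs

def all_reduce_grads(grads):
  # Previously: ctx.all_reduce("sum", grads, experimental_hints=hints)
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce("sum", grads, options=options)
```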

PiperOrigin-RevId: 337547832
Change-Id: I376171672698d5923b4e52f2567d4a584c8e21b6
Author: Ran Chen, 2020-10-16 11:43:26 -07:00, committed by TensorFlower Gardener
Parent: 9ef7492f43
Commit: f196a243ea
36 changed files with 633 additions and 384 deletions

View File

@ -54,6 +54,13 @@
tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
use `tf.data.Dataset.from_tensor_slices` instead.
* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`,
`tf.distribute.StrategyExtended.batch_reduce_to`, and
`tf.distribute.ReplicaContext.all_reduce` is renamed to `options`.
`tf.distribute.experimental.CollectiveHints` is renamed to
`tf.distribute.experimental.CommunicationOptions`.
`tf.distribute.experimental.CollectiveCommunication` is renamed to
`tf.distribute.experimental.CommunicationImplementation`.
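
A hedged migration sketch for the renames above; `MirroredStrategy` and the constant are only stand-ins to keep the snippet self-contained, and the pack size is arbitrary:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # placeholder strategy
value = tf.constant(1.0)

# Before: experimental_hints=tf.distribute.experimental.CollectiveHints(...)
# After:
reduced = strategy.extended.reduce_to(
    "SUM", value, value,
    options=tf.distribute.experimental.CommunicationOptions(
        bytes_per_pack=50 * 1024 * 1024))
```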
## Known Caveats

View File

@ -337,6 +337,7 @@ py_library(
name = "mirrored_strategy",
srcs = ["mirrored_strategy.py"],
deps = [
":collective_util",
":cross_device_ops",
":device_util",
":distribute_lib",
@ -425,6 +426,7 @@ py_library(
srcs = ["collective_all_reduce_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
":collective_util",
":cross_device_ops",
":cross_device_utils",
":input_lib",
@ -669,12 +671,22 @@ py_library(
py_library(
name = "collective_util",
srcs = ["collective_util.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
],
)
tf_py_test(
name = "collective_util_test",
srcs = ["collective_util_test.py"],
deps = [
":collective_util",
"//tensorflow/python/eager:test",
],
)
py_library(
name = "shared_variable_creator",
srcs = ["shared_variable_creator.py"],

View File

@ -98,33 +98,34 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
Tensorflow API.
"""
# TODO(anjalisridhar): Update our guides with examples showing how we can use
# the cluster_resolver argument.
# The starting number for collective keys. This should only be set in tests.
_collective_key_base = 0
def __init__(
self,
communication=cross_device_ops_lib.CollectiveCommunication.AUTO,
cluster_resolver=None):
def __init__(self,
communication=collective_util.CommunicationImplemenation.AUTO,
cluster_resolver=None):
"""Creates the strategy.
Args:
communication: optional
`tf.distribute.experimental.CollectiveCommunication`. This is a hint on
the preferred collective communication implementation. Possible values
include `AUTO`, `RING`, and `NCCL`.
`tf.distribute.experimental.CommunicationImplemenation`. This is a hint
on the preferred collective communication implementation. Possible
values include `AUTO`, `RING`, and `NCCL`.
cluster_resolver: optional
`tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
`tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
"""
# TODO(b/150151677): consider move communication to CollectiveHints.
communication_options = collective_util.Options(
implementation=communication)
super(CollectiveAllReduceStrategy, self).__init__(
CollectiveAllReduceExtended(
self,
communication=communication,
cluster_resolver=cluster_resolver))
cluster_resolver=cluster_resolver,
communication_options=communication_options))
distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
"MultiWorkerMirroredStrategy")
@ -138,7 +139,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
def _from_local_devices(
cls,
devices,
communication=cross_device_ops_lib.CollectiveCommunication.AUTO):
communication=collective_util.CommunicationImplemenation.AUTO):
"""A convenience method to create an object with a list of devices."""
obj = cls(communication)
obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access
@ -162,16 +163,17 @@ class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
__doc__ = CollectiveAllReduceStrategy.__doc__
def __init__(
self,
communication=cross_device_ops_lib.CollectiveCommunication.AUTO,
cluster_resolver=None):
def __init__(self,
communication=collective_util.CommunicationImplemenation.AUTO,
cluster_resolver=None):
"""Initializes the object."""
communication_options = collective_util.Options(
implementation=communication)
super(CollectiveAllReduceStrategyV1, self).__init__(
CollectiveAllReduceExtended(
self,
communication=communication,
cluster_resolver=cluster_resolver))
cluster_resolver=cluster_resolver,
communication_options=communication_options))
distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
"MultiWorkerMirroredStrategy")
# pylint: disable=protected-access
@ -196,16 +198,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
# Times to retry before considering the peer is down.
_check_health_retry_limit = 3
def __init__(self,
container_strategy,
communication,
cluster_resolver):
def __init__(self, container_strategy, cluster_resolver,
communication_options):
self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
assert isinstance(
communication,
cross_device_ops_lib.CollectiveCommunication)
self._communication = communication
self._communication_options = communication_options
self._initialize_strategy(self._cluster_resolver)
self._cfer_fn_cache = weakref.WeakKeyDictionary()
self.experimental_enable_get_next_as_optional = True
@ -255,15 +252,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=local_devices,
group_size=len(local_devices),
collective_keys=self._collective_keys,
communication=self._communication)
collective_keys=self._collective_keys)
# CrossDeviceOps for per host tensors.
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=[self._worker_device],
group_size=self._num_workers,
collective_keys=self._collective_keys,
communication=cross_device_ops_lib.CollectiveCommunication.RING,
)
collective_keys=self._collective_keys)
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
local_devices)
@ -283,8 +277,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
self._rpc_layer = cluster_resolver.rpc_layer
self._warn_nccl_no_gpu()
logging.info("Single-worker MultiWorkerMirroredStrategy with local_devices "
"= %r, communication = %s", local_devices, self._communication)
logging.info(
"Single-worker MultiWorkerMirroredStrategy with local_devices "
"= %r, communication = %s", local_devices,
self._communication_options.implementation)
def _initialize_multi_worker(self, cluster_resolver):
"""Initializes the object for multi-worker training."""
@ -371,15 +367,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=local_devices,
group_size=len(local_devices) * self._num_workers,
collective_keys=self._collective_keys,
communication=self._communication)
collective_keys=self._collective_keys)
# CrossDeviceOps for per host tensors.
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=[self._worker_device],
group_size=self._num_workers,
collective_keys=self._collective_keys,
communication=cross_device_ops_lib.CollectiveCommunication.RING,
)
collective_keys=self._collective_keys)
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
local_devices)
@ -398,9 +391,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
logging.info(
"MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
"task_id = %r, num_workers = %r, local_devices = %r, "
"communication = %s", cluster_spec.as_dict(), task_type,
task_id, self._num_workers, local_devices,
self._communication)
"communication = %s", cluster_spec.as_dict(), task_type, task_id,
self._num_workers, local_devices,
self._communication_options.implementation)
def __del__(self):
self._stop_check_health_thread()
@ -571,8 +564,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
if (not ops.executing_eagerly_outside_functions() and
self._communication ==
cross_device_ops_lib.CollectiveCommunication.NCCL):
self._communication_options.implementation ==
collective_util.CommunicationImplemenation.NCCL):
updated_config.experimental.collective_nccl = True
if not self._cluster_spec:
@ -610,15 +603,14 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else:
return self._host_cross_device_ops
def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
def _gather_to_implementation(self, value, destinations, axis, options):
return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access
value,
destinations=destinations,
axis=axis,
experimental_hints=experimental_hints)
options=options)
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
if (isinstance(value, values.Mirrored) and
reduce_op == reduce_util.ReduceOp.MEAN):
return value
@ -642,7 +634,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
reduce_op,
value,
destinations=destinations,
experimental_hints=experimental_hints)
options=self._communication_options.merge(options))
def _check_health(self):
while True:
@ -704,8 +696,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
reduce_util.ReduceOp.SUM,
dummy_value,
dummy_value,
experimental_hints=collective_util.Hints(
timeout_seconds=self._check_health_initial_timeout))
options=collective_util.Options(
timeout_seconds=self._check_health_initial_timeout,
implementation=collective_util.CommunicationImplemenation.RING))
if context.is_async():
context.async_wait()
except errors.DeadlineExceededError:
@ -731,8 +724,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
logging.info("check health thread stopped")
def _warn_nccl_no_gpu(self):
if ((self._communication ==
cross_device_ops_lib.CollectiveCommunication.NCCL) and
if ((self._communication_options.implementation ==
collective_util.CommunicationImplemenation.NCCL) and
self._num_gpus_per_worker == 0):
logging.warning("Enabled NCCL communication but no GPUs detected/"
"specified.")

View File

@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -18,9 +19,143 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import enum
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# TODO(b/170340570): print deprecation warning for CollectiveCommunication.
@tf_export("distribute.experimental.CommunicationImplemenation",
"distribute.experimental.CollectiveCommunication")
class CommunicationImplemenation(enum.Enum):
"""Cross device communication implementation.
Warning: The alias `tf.distribute.experimental.CollectiveCommunication` is
deprecated and will be removed in a future version. Use
`tf.distribute.experimental.CommunicationImplemenation` instead.
* `AUTO`: Automatically chosen by Tensorflow.
* `RING`: TensorFlow's ring algorithms for all-reduce and
all-gather.
* `NCCL`: NVIDIA®'s NCCL library. This is now only used for all-reduce on
GPUs; all-reduce on CPU, all-gather, and broadcast fall back to RING.
"""
AUTO = "AUTO"
RING = "RING"
NCCL = "NCCL"
# TODO(ayushd): add ncclAllGather implementation.
CollectiveCommunication = CommunicationImplemenation
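
A brief hedged illustration of how this enum feeds into `Options` (defined later in this file); the string form relies on the coercion in `Options.__init__` below:

```python
# Illustration only, not part of the module.
nccl_opts = Options(implementation=CommunicationImplemenation.NCCL)
ring_opts = Options(implementation="ring")  # upper-cased and coerced to the enum
```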
@tf_export("distribute.experimental.CommunicationOptions")
class _OptionsExported(object):
"""Options for cross device communications like All-reduce.
This can be passed to methods like
`tf.distribute.get_replica_context().all_reduce()` to optimize collective
operation performance. Note that these are only hints, which may or may not
change the actual behavior. Some options only apply to certain strategies and
are ignored by others.
One common optimization is to break gradients all-reduce into multiple packs
so that weight updates can overlap with gradient all-reduce.
Examples:
```python
options = tf.distribute.experimental.CommunicationOptions(
bytes_per_pack=50 * 1024 * 1024,
timeout_seconds=120,
implementation=tf.distribute.experimental.CommunicationImplemenation.NCCL)
grads = tf.distribute.get_replica_context().all_reduce(
'sum', grads, options=options)
optimizer.apply_gradients(zip(grads, vars),
experimental_aggregate_gradients=False)
```
"""
def __new__(cls, *args, **kwargs):
return Options.__new__(Options, *args, **kwargs)
def __init__(self,
bytes_per_pack=0,
timeout_seconds=None,
implementation=CommunicationImplemenation.AUTO):
"""Creates a CollectiveHints.
Args:
bytes_per_pack: a non-negative integer. Breaks collective operations into
packs of certain size. If it's zero, the value is determined
automatically. This only applies to all-reduce with
`MultiWorkerMirroredStrategy` currently.
timeout_seconds: a float or None, timeout in seconds. If not None, the
collective raises `tf.errors.DeadlineExceededError` if it takes longer
than this timeout. Zero disables timeout. This can be useful when
debugging hanging issues. This should only be used for debugging since
it creates a new thread for each collective, i.e. an overhead of
`timeout_seconds * num_collectives_per_second` more threads. This only
works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
implementation: a `tf.distribute.experimental.CommunicationImplemenation`.
This is a hint on the preferred communication implementation. Possible
values include `AUTO`, `RING`, and `NCCL`. NCCL is generally more
performant for GPU, but doesn't work for CPU. This only works for
`tf.distribute.experimental.MultiWorkerMirroredStrategy`.
Raises:
ValueError: When arguments have invalid value.
"""
pass
class Options(object):
"""Implementation of OptionsInterface."""
def __init__(self,
bytes_per_pack=0,
timeout_seconds=None,
implementation=CommunicationImplemenation.AUTO):
if bytes_per_pack < 0:
raise ValueError("bytes_per_pack must be non-negative")
if isinstance(implementation, str):
implementation = CommunicationImplemenation(implementation.upper())
if not isinstance(implementation, CommunicationImplemenation):
raise ValueError("implementation should be a "
"tf.distribute.experimental.CommunicationImplemenation")
self.bytes_per_pack = bytes_per_pack
self.timeout_seconds = timeout_seconds
self.implementation = implementation
__init__.__doc__ = _OptionsExported.__init__.__doc__
def merge(self, options):
"""Merges with another options and returns a new one.
Values specified in the `options` takes precedence if they're not the
default.
Args:
options: a `tf.distribute.experimental.CommunicationOptions`.
Returns:
A new `tf.distribute.experimental.CommunicationOptions`.
"""
merged = copy.deepcopy(self)
if options is None:
return merged
if options.bytes_per_pack != 0:
merged.bytes_per_pack = options.bytes_per_pack
if options.timeout_seconds is not None:
merged.timeout_seconds = options.timeout_seconds
if options.implementation != CommunicationImplemenation.AUTO:
merged.implementation = options.implementation
return merged
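
To make the precedence rule concrete, a hedged sketch with invented values:

```python
# Hypothetical: defaults held by a strategy, merged with per-call options.
default = Options(bytes_per_pack=32 * 1024 * 1024)
per_call = Options(timeout_seconds=10)  # bytes_per_pack left at the 0 default
merged = default.merge(per_call)
assert merged.bytes_per_pack == 32 * 1024 * 1024  # default 0 in per_call is ignored
assert merged.timeout_seconds == 10               # non-default per-call value wins
assert merged.implementation == CommunicationImplemenation.AUTO
```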
@tf_export("distribute.experimental.CollectiveHints")
class Hints(object):
"""Hints for collective operations like AllReduce.
@ -61,6 +196,12 @@ class Hints(object):
"""
@deprecation.deprecated(
None, "use distribute.experimental.CommunicationOptions instead")
def __new__(cls, bytes_per_pack=0, timeout_seconds=None):
return Options(
bytes_per_pack=bytes_per_pack, timeout_seconds=timeout_seconds)
def __init__(self, bytes_per_pack=0, timeout_seconds=None):
"""Creates a CollectiveHints.
@ -80,7 +221,4 @@ class Hints(object):
Raises:
ValueError: When arguments have invalid value.
"""
if bytes_per_pack < 0:
raise ValueError("bytes_per_pack must be non-negative")
self.bytes_per_pack = bytes_per_pack
self.timeout_seconds = timeout_seconds
pass

View File

@ -0,0 +1,41 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for utilities for collectives."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import collective_util
from tensorflow.python.eager import test
class OptionsTest(test.TestCase):
def testCreateOptionsViaExportedAPI(self):
options = collective_util._OptionsExported()
self.assertIsInstance(options, collective_util.Options)
def testCreateOptionsViaHints(self):
with self.assertLogs() as cm:
options = collective_util.Hints(50, 1)
self.assertTrue(any("is deprecated" in msg for msg in cm.output))
self.assertIsInstance(options, collective_util.Options)
self.assertEqual(options.bytes_per_pack, 50)
self.assertEqual(options.timeout_seconds, 1)
if __name__ == "__main__":
test.main()

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import collections
import copy
import enum
import threading
import six
@ -251,11 +250,7 @@ class CrossDeviceOps(object):
# Returns 1 by default, the value may be overridden by sub classes.
return 1
def reduce(self,
reduce_op,
per_replica_value,
destinations,
experimental_hints=None):
def reduce(self, reduce_op, per_replica_value, destinations, options=None):
"""Reduce `per_replica_value` to `destinations`.
See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
@ -272,8 +267,8 @@ class CrossDeviceOps(object):
`destinations`. Note that if it's a `tf.Variable`, the value is reduced
to the devices of that variable, and this method doesn't update the
variable.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A `tf.Tensor` or `tf.distribute.DistributedValues`.
@ -283,6 +278,8 @@ class CrossDeviceOps(object):
`tf.distribute.DistributedValues` or if destinations is not a string,
`tf.Variable` or `tf.distribute.DistributedValues`.
"""
if options is None:
options = collective_util.Options()
if not isinstance(per_replica_value, value_lib.DistributedValues):
per_replica_value = _make_tensor_into_per_replica(per_replica_value)
@ -296,16 +293,12 @@ class CrossDeviceOps(object):
v = array_ops.identity(per_replica_value.values[0])
return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
if experimental_hints is None:
experimental_hints = collective_util.Hints()
if options is None:
options = collective_util.Options()
return self.reduce_implementation(reduce_op, per_replica_value,
destinations, experimental_hints)
destinations, options)
def _gather(self,
per_replica_value,
destinations,
axis,
experimental_hints=None):
def _gather(self, per_replica_value, destinations, axis, options=None):
"""Gather `per_replica_value` to `destinations`.
Args:
@ -319,8 +312,8 @@ class CrossDeviceOps(object):
variable.
axis: specifies the dimension to gather along within each replica's
tensor.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A `tf.Tensor` or `tf.distribute.DistributedValues`
@ -330,8 +323,11 @@ class CrossDeviceOps(object):
`tf.distribute.DistributedValues` or if destinations is not a string,
`tf.Variable` or `tf.distribute.DistributedValues`.
"""
if experimental_hints is None:
experimental_hints = collective_util.Hints()
if isinstance(per_replica_value, ops.IndexedSlices):
raise NotImplementedError("gather/all_gather does not support "
"IndexedSlices")
if options is None:
options = collective_util.Options()
if not isinstance(per_replica_value, value_lib.DistributedValues):
per_replica_value = _make_tensor_into_per_replica(per_replica_value)
@ -347,10 +343,10 @@ class CrossDeviceOps(object):
return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
return self._gather_implementation(per_replica_value, destinations, axis,
experimental_hints)
options)
def _gather_implementation(self, per_replica_value, destinations, axis,
experimental_hints):
options):
"""Implementation of `gather` method of `tf.distribute.CrossDeviceOps`.
Overriding this method is useful for subclass implementers.
@ -366,8 +362,8 @@ class CrossDeviceOps(object):
variable.
axis: specifies the dimension to gather along within each replica's
tensor.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A `tf.Tensor` or `tf.distribute.DistributedValues`.
@ -380,10 +376,7 @@ class CrossDeviceOps(object):
raise NotImplementedError(
"_gather method must be implemented in descendants.")
def batch_reduce(self,
reduce_op,
value_destination_pairs,
experimental_hints=None):
def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
"""Reduce values to destinations in batches.
See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
@ -394,8 +387,8 @@ class CrossDeviceOps(object):
combined.
value_destination_pairs: a sequence of (value, destinations) pairs. See
`tf.distribute.CrossDeviceOps.reduce` for descriptions.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
@ -405,6 +398,8 @@ class CrossDeviceOps(object):
ValueError: if `value_destination_pairs` is not an iterable of
tuples of `tf.distribute.DistributedValues` and destinations.
"""
if options is None:
options = collective_util.Options()
# TODO(yuefengz): if destinations are different, split into several
# `_batch_reduce` invocations.
if not _validate_value_destination_pairs(value_destination_pairs):
@ -425,10 +420,10 @@ class CrossDeviceOps(object):
for v, _ in value_destination_pairs
]
if experimental_hints is None:
experimental_hints = collective_util.Hints()
if options is None:
options = collective_util.Options()
return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
experimental_hints)
options)
def broadcast(self, tensor, destinations):
"""Broadcast `tensor` to `destinations`.
@ -451,7 +446,7 @@ class CrossDeviceOps(object):
@doc_controls.for_subclass_implementers
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
experimental_hints):
options):
"""Implementation of `reduce`.
Overriding this method is useful for subclass implementers.
@ -467,8 +462,8 @@ class CrossDeviceOps(object):
`destinations`. Note that if it's a `tf.Variable`, the value is reduced
to the devices of that variable, this method doesn't update the
variable.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A `tf.Tensor` or `tf.distribute.DistributedValues`.
@ -483,7 +478,7 @@ class CrossDeviceOps(object):
@doc_controls.for_subclass_implementers
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
experimental_hints):
options):
"""Implementation of `batch_reduce`.
Overriding this method is useful for subclass implementers.
@ -493,8 +488,8 @@ class CrossDeviceOps(object):
combined.
value_destination_pairs: a sequence of (value, destinations) pairs. See
`reduce` for descriptions.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
options: a `tf.distribute.experimental.CommunicationOptions`. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
@ -558,8 +553,8 @@ class ReductionToOneDevice(CrossDeviceOps):
super(ReductionToOneDevice, self).__init__()
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
experimental_hints):
del experimental_hints # Unused.
options):
del options # Unused.
if check_destinations(destinations):
devices = get_devices_from(destinations)
else:
@ -573,8 +568,8 @@ class ReductionToOneDevice(CrossDeviceOps):
return self.broadcast(reduced, destinations)
def _gather_implementation(self, per_replica_value, destinations, axis,
experimental_hints):
del experimental_hints # Unused.
options):
del options # Unused.
if check_destinations(destinations):
devices = get_devices_from(destinations)
else:
@ -587,10 +582,10 @@ class ReductionToOneDevice(CrossDeviceOps):
return self.broadcast(gathered, destinations)
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
experimental_hints):
options):
return [
self.reduce_implementation(
reduce_op, t, destinations=v, experimental_hints=experimental_hints)
reduce_op, t, destinations=v, options=options)
for t, v in value_destination_pairs
]
@ -806,8 +801,8 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
super(AllReduceCrossDeviceOps, self).__init__()
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
experimental_hints):
del experimental_hints # Unused.
options):
del options # Unused.
if _devices_match(per_replica_value, destinations):
return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
else:
@ -815,13 +810,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
destinations)
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
experimental_hints):
options):
if _all_devices_match(value_destination_pairs):
return self._batch_all_reduce(reduce_op,
[v[0] for v in value_destination_pairs])
else:
return [
self.reduce_implementation(reduce_op, value, dest, experimental_hints)
self.reduce_implementation(reduce_op, value, dest, options)
for value, dest in value_destination_pairs
]
@ -882,13 +877,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
reduce_op, zip(sparse_values, sparse_values))
def _gather_implementation(self, per_replica_value, destinations, axis,
experimental_hints):
options):
logging.warning("gather/all_gather with NCCL or HierarchicalCopy is not "
"supported. Falling back to gather on one device and "
"then broadcast. We're working on a more efficient "
"implementation.")
return ReductionToOneDevice()._gather(per_replica_value, destinations, axis, # pylint: disable=protected-access
experimental_hints)
options)
# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
@ -979,20 +974,9 @@ class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
num_packs=num_packs)
@tf_export("distribute.experimental.CollectiveCommunication")
class CollectiveCommunication(enum.Enum):
"""Communication choices for CollectiveOps.
* `AUTO`: Default to runtime's automatic choices.
* `RING`: TensorFlow's ring algorithms for all-reduce and
all-gather.
* `NCCL`: Use ncclAllReduce for all-reduce, and ring algorithms for
all-gather.
"""
AUTO = "AUTO"
RING = "RING"
NCCL = "NCCL"
# TODO(ayushd): add ncclAllGather implementation.
# TODO(crccw): remove after migrating all callers.
CollectiveCommunication = collective_util.CommunicationImplemenation
CommunicationImplemenation = collective_util.CommunicationImplemenation
# TODO(yuefengz): support in-graph collective all-reduce.
@ -1003,11 +987,7 @@ class CollectiveAllReduce(CrossDeviceOps):
all workers and then put results on the right destinations.
"""
def __init__(self,
devices,
group_size,
collective_keys=None,
communication=CollectiveCommunication.AUTO):
def __init__(self, devices, group_size, collective_keys=None):
"""Initializes the object.
Args:
@ -1015,7 +995,6 @@ class CollectiveAllReduce(CrossDeviceOps):
group_size: the global group size. For between-graph replicated training
it's the total number of devices across all workers.
collective_keys: an optional CollectiveKey object.
communication: indicates which collective communication to use.
"""
if group_size % len(devices) > 0:
raise ValueError("group_size must be divisible by the number of devices.")
@ -1023,7 +1002,6 @@ class CollectiveAllReduce(CrossDeviceOps):
self._group_size = group_size
self._collective_keys = (collective_keys or
cross_device_utils.CollectiveKeys())
self._communication = communication
# This lock guards all collective launches, i.e. calls to
# cross_device_utils.build_collectve_*.
#
@ -1063,10 +1041,10 @@ class CollectiveAllReduce(CrossDeviceOps):
return self._group_size / len(self._devices)
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
experimental_hints):
options):
values_util.mark_as_unsaveable()
all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
experimental_hints)[0]
options)[0]
devices = get_devices_from(destinations)
if _devices_match(per_replica_value, destinations):
@ -1095,13 +1073,13 @@ class CollectiveAllReduce(CrossDeviceOps):
return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
experimental_hints):
options):
values_util.mark_as_unsaveable()
all_devices_match = _all_devices_match(value_destination_pairs)
if all_devices_match:
return self._batch_all_reduce(reduce_op,
[v[0] for v in value_destination_pairs],
experimental_hints)
options)
else:
if not all_devices_match:
logging.log_first_n(
@ -1109,43 +1087,41 @@ class CollectiveAllReduce(CrossDeviceOps):
"destinations are different.", 10)
return [
self.reduce_implementation(reduce_op, value, dest, experimental_hints)
self.reduce_implementation(reduce_op, value, dest, options)
for value, dest in value_destination_pairs
]
def _batch_all_reduce(self, reduce_op, per_replica_values,
experimental_hints):
def _batch_all_reduce(self, reduce_op, per_replica_values, options):
"""All reduce algorithm in a batch."""
dense_values, dense_indices, sparse_values, sparse_indices = (
cross_device_utils.split_by_sparsity(per_replica_values))
if dense_values:
dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
experimental_hints)
options)
else:
dense_results = []
if sparse_values:
sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
sparse_values,
experimental_hints)
sparse_values, options)
else:
sparse_results = []
return cross_device_utils.stitch_values(
((dense_results, dense_indices), (sparse_results, sparse_indices)))
def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
experimental_hints):
def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values, options):
"""All-reduce across all workers in a batch."""
batch_size = len(per_replica_values)
# Pass self._communication to the runtime as a communication hint.
communication = self._communication.value
implementation = options.implementation.value
# For now, we use NCCL only when batch_size > 1 since we don't have a way to
# order NCCL launches. We're hoping that there's only one batched
# all-reduce, which is the gradients.
# TODO(b/132575814): switch to NCCL for all collectives when communication
# is NCCL if and only if we can order collectives deterministically.
if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
communication = CollectiveCommunication.AUTO.value
# is NCCL.
if (options.implementation == CommunicationImplemenation.NCCL and
batch_size == 1):
implementation = CommunicationImplemenation.AUTO.value
# Reverse the lists so that there's better chance that values follows
# the order in which they are calculated (e.g. when they're gradients), so
@ -1166,15 +1142,15 @@ class CollectiveAllReduce(CrossDeviceOps):
with self._lock:
for i in range(len(self._devices)):
packs = cross_device_utils.group_by_size(
values_by_device[i], experimental_hints.bytes_per_pack)
values_by_device[i], options.bytes_per_pack)
if not context.executing_eagerly() and i == 0:
logging.info(
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
"group_size = %d, communication_hint = %s, num_packs = %d",
batch_size, len(self._launchers), self._group_size, communication,
len(packs))
"group_size = %d, implementation = %s, num_packs = %d",
batch_size, len(self._launchers), self._group_size,
implementation, len(packs))
outputs_by_device.append(self._launchers[i].batch_all_reduce(
packs, communication, experimental_hints.timeout_seconds))
packs, implementation, options.timeout_seconds))
for e in self._executors:
e.wait()
@ -1191,8 +1167,7 @@ class CollectiveAllReduce(CrossDeviceOps):
# Reverse the order of reduced value to recover the order in the input.
return list(reversed(mirrored))
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values,
experimental_hints):
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values, options):
"""All-reduce IndexedSlices across all workers in a batch."""
logging.log_first_n(
@ -1200,8 +1175,13 @@ class CollectiveAllReduce(CrossDeviceOps):
"%d all-reduces, group_size = %d" %
(len(per_replica_values), self._group_size), 10)
# Pass self._communication to the runtime as a communication hint.
communication_hint = self._communication.value
implementation = options.implementation.value
# For now, we use NCCL only when batch_size > 1.
# TODO(b/132575814): switch to NCCL for all collectives when implementation
# is NCCL.
if options.implementation == CommunicationImplemenation.NCCL and len(
per_replica_values) == 1:
implementation = CommunicationImplemenation.AUTO.value
gathered_values = []
with self._lock:
@ -1209,8 +1189,7 @@ class CollectiveAllReduce(CrossDeviceOps):
outputs = []
for i in range(len(self._devices)):
outputs.append(self._launchers[i].all_reduce_indexed_slices(
per_replica.values[i], communication_hint,
experimental_hints.timeout_seconds))
per_replica.values[i], implementation, options.timeout_seconds))
gathered_values.append(outputs)
mirrored = []
@ -1225,10 +1204,9 @@ class CollectiveAllReduce(CrossDeviceOps):
return mirrored
def _gather_implementation(self, per_replica_value, destinations, axis,
experimental_hints):
options):
all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
values_util.mark_as_unsaveable()
all_gathered = self._batch_all_gather([per_replica_value], axis,
experimental_hints)[0]
devices = get_devices_from(destinations)
if _devices_match(per_replica_value, destinations):
@ -1254,16 +1232,23 @@ class CollectiveAllReduce(CrossDeviceOps):
index.append(array_ops.identity(all_gathered._primary)) # pylint: disable=protected-access
return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
def _batch_all_gather(self, per_replica_values, axis, experimental_hints):
def _batch_all_gather(self, per_replica_values, axis, options):
"""all gather multiple per-replica-values."""
batch_size = len(per_replica_values)
# Pass self._communication to the runtime as a communication hint.
communication = self._communication.value
# Pass options.implementation to the runtime as a communication
# implementation hint.
implementation = options.implementation.value
# For now, we use NCCL only when batch_size > 1.
# TODO(b/132575814): switch to NCCL for all collectives when implementation
# is NCCL.
if (options.implementation == CommunicationImplemenation.NCCL and
batch_size == 1):
implementation = CommunicationImplemenation.AUTO.value
logging.log_first_n(
logging.INFO, "Collective batch_all_gather: %d all-gathers, "
"num_devices = %d, group_size = %d, communication_hint = %s, " %
(batch_size, len(self._devices), self._group_size, communication), 10)
"num_devices = %d, group_size = %d, implementation = %s, " %
(batch_size, len(self._devices), self._group_size, implementation), 10)
def compute_gathered_values():
gathered_values = []
@ -1272,8 +1257,8 @@ class CollectiveAllReduce(CrossDeviceOps):
outputs = []
for i in range(len(self._devices)):
outputs.append(self._launchers[i].all_gather(
per_replica.values[i], axis, communication,
experimental_hints.timeout_seconds))
per_replica.values[i], axis, implementation,
options.timeout_seconds))
gathered_values.append(outputs)
return gathered_values
@ -1292,8 +1277,7 @@ class CollectiveAllReduce(CrossDeviceOps):
# distribute_coordinator deep-copies the strategy object, so
# CollectiveAllReduce needs to support deep copy as well.
collective_keys = copy.deepcopy(self._collective_keys, memo)
return CollectiveAllReduce(self._devices, self._group_size, collective_keys,
self._communication)
return CollectiveAllReduce(self._devices, self._group_size, collective_keys)
def select_cross_device_ops(devices, session_config=None):

View File

@ -48,7 +48,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.util import nest
CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication
CommunicationImplemenation = collective_util.CommunicationImplemenation
ReduceOp = reduce_util.ReduceOp
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
IndexedSlices = indexed_slices.IndexedSlices
@ -140,14 +140,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
global_mpr_1p.runner.run(enable_collective_ops)
global_mpr_2p.runner.run(enable_collective_ops)
def make_collective(self, num_processes, gpu_per_process, communication):
def make_collective(self, num_processes, gpu_per_process):
"""Returns collectives and other info to be used in tests.
Args:
num_processes: an integer indicating the number of processes that
participate in the collective.
gpu_per_process: number of GPUs (0 if no GPUs) used by each process.
communication: one of `CollectiveCommunication`.
Returns:
A tuple of (collective, devices, group_size) where collective is an instance
@ -167,7 +166,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
]
group_size = num_processes * len(devices)
collective = cross_device_ops_lib.CollectiveAllReduce(
devices=devices, group_size=group_size, communication=communication)
devices=devices, group_size=group_size)
return collective, devices, cluster_resolver.task_id
def as_list(self, value):
@ -202,12 +201,12 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
"num_processes",
"gpus_per_process",
"reduce_op",
"communication",
"communication_options",
"use_scoped_allocator",
])
RunOptions.__new__.__defaults__ = (["eager",
"func_graph"], 2, 0, ReduceOp.SUM,
CollectiveCommunication.AUTO, True)
collective_util.Options(), True)
def reduce_and_verify(self, inputs, expect, options):
"""Reduce the given `inputs` and verify the output matches `expect`.
@ -222,14 +221,14 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
def replica_fn():
collective, devices, pid = self.make_collective(options.num_processes,
options.gpus_per_process,
options.communication)
options.gpus_per_process)
def reduce_fn():
value_fn = lambda device_idx: inputs[pid * len(devices) + device_idx]
per_replica_value = make_per_replica_value(value_fn, devices)
reduced_values = collective.reduce(options.reduce_op, per_replica_value,
per_replica_value)
per_replica_value,
options.communication_options)
reduced_values = self.as_list(reduced_values)
self.assertAllEqual(devices, [v.device for v in reduced_values])
return [ops.convert_to_tensor(v) for v in reduced_values]
@ -261,8 +260,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
cross_device_utils.CollectiveReplicaLauncher._use_scoped_allocator = (
options.use_scoped_allocator)
collective, devices, pid = self.make_collective(options.num_processes,
options.gpus_per_process,
options.communication)
options.gpus_per_process)
def batch_reduce_fn():
batch_size = len(inputs[0])
@ -275,7 +273,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
per_replica_value = make_per_replica_value(value_fn, devices)
value_dst_pairs.append((per_replica_value, per_replica_value))
reduced_values = collective.batch_reduce(options.reduce_op,
value_dst_pairs)
value_dst_pairs,
options.communication_options)
reduced_values = [self.as_list(v) for v in reduced_values]
for v in reduced_values:
self.assertAllEqual(devices, [t.device for t in v])
@ -298,20 +297,21 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=[1, 2],
required_gpus=[0, 1, 2],
communication=[
implementation=[
# NCCL is only used for batch reduce, so we are not including
# NCCL combination here.
CollectiveCommunication.AUTO,
CollectiveCommunication.RING
CommunicationImplemenation.AUTO,
CommunicationImplemenation.RING
],
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN]))
def testAllReduceDense(self, num_processes, required_gpus, communication,
def testAllReduceDense(self, num_processes, required_gpus, implementation,
reduce_op):
options = self.RunOptions(
num_processes=num_processes,
gpus_per_process=required_gpus,
reduce_op=reduce_op,
communication=communication)
communication_options=collective_util.Options(
implementation=implementation))
group_size = options.num_processes * (options.gpus_per_process or 1)
inputs_data = [1.0, 2.0, 3.0, 4.0]
@ -330,22 +330,23 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=[1, 2],
required_gpus=[0, 1, 2],
communication=[
implementation=[
# NCCL is only used for batch reduce, so we are not including
# NCCL combination here.
CollectiveCommunication.AUTO,
CollectiveCommunication.RING
CommunicationImplemenation.AUTO,
CommunicationImplemenation.RING
],
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
reduce_op=ReduceOp.SUM))
def testAllReduceSparse(self, num_processes, required_gpus, communication,
def testAllReduceSparse(self, num_processes, required_gpus, implementation,
reduce_op):
options = self.RunOptions(
mode=["func_graph"], # Sparse reduce is not supported in eager.
num_processes=num_processes,
gpus_per_process=required_gpus,
reduce_op=reduce_op,
communication=communication)
communication_options=collective_util.Options(
implementation=implementation))
group_size = options.num_processes * (options.gpus_per_process or 1)
inputs_data = [
@ -399,17 +400,17 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=[1, 2],
required_gpus=[0, 1, 2],
communication=[
CollectiveCommunication.AUTO, CollectiveCommunication.RING,
CollectiveCommunication.NCCL
implementation=[
CommunicationImplemenation.AUTO, CommunicationImplemenation.RING,
CommunicationImplemenation.NCCL
],
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
use_scoped_allocator=[True, False]))
def testBatchAllReduceDense(self, num_processes, required_gpus, communication,
reduce_op, use_scoped_allocator):
if required_gpus == 0 and communication == CollectiveCommunication.NCCL:
def testBatchAllReduceDense(self, num_processes, required_gpus,
implementation, reduce_op, use_scoped_allocator):
if required_gpus == 0 and implementation == CommunicationImplemenation.NCCL:
self.skipTest("Skip CPU + NCCL combination")
if num_processes == 2 and communication == CollectiveCommunication.NCCL:
if num_processes == 2 and implementation == CommunicationImplemenation.NCCL:
self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
"physical GPUs for every process.")
@ -417,7 +418,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
num_processes=num_processes,
gpus_per_process=required_gpus,
reduce_op=reduce_op,
communication=communication,
communication_options=collective_util.Options(
implementation=implementation),
use_scoped_allocator=use_scoped_allocator)
group_size = options.num_processes * (options.gpus_per_process or 1)
@ -437,19 +439,19 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=[1, 2],
required_gpus=[0, 1, 2],
communication=[
CollectiveCommunication.AUTO,
CollectiveCommunication.RING,
CollectiveCommunication.NCCL,
implementation=[
CommunicationImplemenation.AUTO,
CommunicationImplemenation.RING,
CommunicationImplemenation.NCCL,
],
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
reduce_op=ReduceOp.SUM,
use_scoped_allocator=[True, False]))
def testBatchAllReduceSparse(self, num_processes, required_gpus,
communication, reduce_op, use_scoped_allocator):
if required_gpus == 0 and communication == CollectiveCommunication.NCCL:
implementation, reduce_op, use_scoped_allocator):
if required_gpus == 0 and implementation == CommunicationImplemenation.NCCL:
self.skipTest("Skip CPU + NCCL combination")
if num_processes == 2 and communication == CollectiveCommunication.NCCL:
if num_processes == 2 and implementation == CommunicationImplemenation.NCCL:
self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
"physical GPUs for every process.")
@ -458,7 +460,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
num_processes=num_processes,
gpus_per_process=required_gpus,
reduce_op=reduce_op,
communication=communication,
communication_options=collective_util.Options(
implementation=implementation),
use_scoped_allocator=use_scoped_allocator)
group_size = options.num_processes * (options.gpus_per_process or 1)
@ -522,24 +525,23 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
required_gpus=[0, 1, 2],
axis=[0, 1, 2],
func_mode=["eager", "func_graph"],
communication=[
CollectiveCommunication.NCCL,
CollectiveCommunication.AUTO,
CollectiveCommunication.RING
implementation=[
CommunicationImplemenation.NCCL, CommunicationImplemenation.AUTO,
CommunicationImplemenation.RING
]))
def testAllGatherSameShape(self, num_processes, required_gpus, communication,
def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
func_mode, axis):
def replica_fn():
collective, devices, _ = self.make_collective(num_processes,
required_gpus,
communication)
required_gpus)
options = collective_util.Options(implementation=implementation)
value = constant_op.constant([[[1, 2], [1, 2]]], dtype=dtypes.float32)
def gather_fn():
per_replica_value = make_per_replica_value(value, devices)
gathered_values = collective._gather(
per_replica_value, per_replica_value, axis=axis)
per_replica_value, per_replica_value, axis=axis, options=options)
gathered_values = self.as_list(gathered_values)
# Skip checking devices in eager. In eager the device attribute doesn't
# reflect the actual device of the tensor.
@ -565,17 +567,17 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=1,
required_gpus=2,
communication=[
CollectiveCommunication.NCCL, CollectiveCommunication.RING
implementation=[
CommunicationImplemenation.NCCL, CommunicationImplemenation.RING
]))
def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
required_gpus,
communication):
implementation):
def replica_fn():
collective, devices, _ = self.make_collective(num_processes,
required_gpus,
communication)
required_gpus)
options = collective_util.Options(implementation=implementation)
# We would like to simulate the following sequence:
# thread-0 device0 device1
@ -604,14 +606,15 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
def thread_fn():
reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
[(v0, v0), (v0, v0)])
[(v0, v0), (v0, v0)], options)
self.assertAllEqual(reduced[0].values, [2.0, 2.0])
self.assertAllEqual(reduced[1].values, [2.0, 2.0])
t = threading.Thread(target=thread_fn)
t.start()
reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v1, v1),
(v1, v1)])
(v1, v1)],
options)
self.assertAllEqual(reduced[0].values, [4.0, 4.0])
self.assertAllEqual(reduced[1].values, [4.0, 4.0])
t.join()
@ -622,16 +625,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=1,
required_gpus=2,
communication=[
CollectiveCommunication.NCCL, CollectiveCommunication.RING
implementation=[
CommunicationImplemenation.NCCL, CommunicationImplemenation.RING
]))
def testInputsAreFunctionArgs(self, num_processes, required_gpus,
communication):
implementation):
def replica_fn():
collective, devices, _ = self.make_collective(num_processes,
required_gpus,
communication)
required_gpus)
options = collective_util.Options(implementation=implementation)
@def_function.function
def reduce_fn(v):
@ -641,7 +644,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
# We only use NCCL for batch reduce with two or more values, so we use
# two values here.
reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v),
(v, v)])
(v, v)],
options)
self.assertEqual(reduced[0].values[0].device, devices[0])
self.assertEqual(reduced[0].values[1].device, devices[1])
self.assertEqual(reduced[1].values[0].device, devices[0])
@ -660,21 +664,23 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=2,
required_gpus=[0, 1],
communication=[CollectiveCommunication.RING]))
def testTimeoutReduceDense(self, num_processes, communication, required_gpus):
implementation=[CommunicationImplemenation.RING]))
def testTimeoutReduceDense(self, num_processes, implementation,
required_gpus):
def replica_fn():
collective, devices, task_id = self.make_collective(
num_processes, required_gpus, communication)
num_processes, required_gpus)
if task_id != 0:
return
v = make_per_replica_value(1.0, devices)
hints = collective_util.Hints(timeout_seconds=1)
options = collective_util.Options(
timeout_seconds=1, implementation=implementation)
@def_function.function
def reduce_dense():
collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)
collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)
# The collective should time out because we only launch it on worker-0,
# while there're two workers in total.
@ -687,23 +693,24 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=2,
required_gpus=[0, 1],
communication=[CollectiveCommunication.RING]))
def testTimeoutBatchReduceDense(self, num_processes, communication,
implementation=[CommunicationImplemenation.RING]))
def testTimeoutBatchReduceDense(self, num_processes, implementation,
required_gpus):
def replica_fn():
collective, devices, task_id = self.make_collective(
num_processes, required_gpus, communication)
num_processes, required_gpus)
if task_id != 0:
return
v = make_per_replica_value(1.0, devices)
hints = collective_util.Hints(timeout_seconds=1)
options = collective_util.Options(
timeout_seconds=1, implementation=implementation)
@def_function.function
def batch_reduce_dense():
collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)],
hints)
options)
# The collective should time out because we only launch it on worker-0,
# while there're two workers in total.
@ -716,24 +723,25 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=2,
required_gpus=[0, 1],
communication=[CollectiveCommunication.RING]))
def testTimeoutReduceSparse(self, num_processes, communication,
implementation=[CommunicationImplemenation.RING]))
def testTimeoutReduceSparse(self, num_processes, implementation,
required_gpus):
def replica_fn():
collective, devices, task_id = self.make_collective(
num_processes, required_gpus, communication)
num_processes, required_gpus)
if task_id != 0:
return
v = make_per_replica_value(
IndexedSlicesValue(
values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
hints = collective_util.Hints(timeout_seconds=1)
options = collective_util.Options(
timeout_seconds=1, implementation=implementation)
@def_function.function
def reduce_sparse():
collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)
collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)
# The collective should time out because we only launch it on worker-0,
# while there're two workers in total.
@ -746,25 +754,26 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
combinations.combine(
num_processes=2,
required_gpus=[0, 1],
communication=[CollectiveCommunication.RING]))
implementation=[CommunicationImplemenation.RING]))
def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
communication):
implementation):
def replica_fn():
collective, devices, task_id = self.make_collective(
num_processes, required_gpus, communication)
num_processes, required_gpus)
if task_id != 0:
return
v = make_per_replica_value(
IndexedSlicesValue(
values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
hints = collective_util.Hints(timeout_seconds=1)
options = collective_util.Options(
timeout_seconds=1, implementation=implementation)
@def_function.function
def batch_reduce_sparse():
collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)],
hints)
options)
# The collective should time out because we only launch it on worker-0,
# while there're two workers in total.

View File

@ -2182,7 +2182,7 @@ class StrategyExtendedV2(object):
dst = device_util.current() or self._default_device or "/device:CPU:0"
return self._local_results(self.reduce_to(reduce_op, value, dst))[0]
def reduce_to(self, reduce_op, value, destinations, experimental_hints=None):
def reduce_to(self, reduce_op, value, destinations, options=None):
"""Combine (via e.g. sum or mean) values across replicas.
`reduce_to` aggregates `tf.distribute.DistributedValues` and distributed
@ -2247,12 +2247,17 @@ class StrategyExtendedV2(object):
`destinations`. Note that if it's a `tf.Variable`, the value is reduced
to the devices of that variable, and this method doesn't update the
variable.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.
Returns:
A tensor or value reduced to `destinations`.
"""
if options is None:
options = collective_util.Options()
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(destinations, (list, tuple))
assert not isinstance(reduce_op, variable_scope.VariableAggregation)
@ -2260,17 +2265,12 @@ class StrategyExtendedV2(object):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
assert (reduce_op == reduce_util.ReduceOp.SUM or
reduce_op == reduce_util.ReduceOp.MEAN)
if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self._reduce_to(reduce_op, value, destinations, experimental_hints)
return self._reduce_to(reduce_op, value, destinations, options)
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
raise NotImplementedError("must be implemented in descendants")
def batch_reduce_to(self,
reduce_op,
value_destination_pairs,
experimental_hints=None):
def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None):
"""Combine multiple `reduce_to` calls into one for faster execution.
Similar to `reduce_to`, but accepts a list of (value, destinations) pairs.
@ -2325,30 +2325,30 @@ class StrategyExtendedV2(object):
"SUM", "MEAN".
value_destination_pairs: a sequence of (value, destinations) pairs. See
`tf.distribute.Strategy.reduce_to` for descriptions.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. Options for
performing collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in its constructor. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A list of reduced values, one per pair in `value_destination_pairs`.
"""
if options is None:
options = collective_util.Options()
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(reduce_op, variable_scope.VariableAggregation)
if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self._batch_reduce_to(reduce_op, value_destination_pairs,
experimental_hints)
return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
return [
self.reduce_to(
reduce_op, t, destinations=v, experimental_hints=experimental_hints)
self.reduce_to(reduce_op, t, destinations=v, options=options)
for t, v in value_destination_pairs
]
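A similar sketch for the batched form; `reduce_gradients` and `grads_and_vars` are hypothetical placeholders, not code from this change:

```python
import tensorflow as tf

# Hypothetical helper: `grads_and_vars` would come from a tape or optimizer.
def reduce_gradients(strategy, grads_and_vars, options=None):
  # One batched cross-device reduction; each gradient is reduced onto the
  # devices of its own variable, honoring the per-call options if given.
  return strategy.extended.batch_reduce_to(
      tf.distribute.ReduceOp.SUM,
      [(grad, var) for grad, var in grads_and_vars],
      options=options)
```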
def _gather_to(self, value, destinations, axis, experimental_hints=None):
def _gather_to(self, value, destinations, axis, options=None):
"""Gather `value` across replicas along axis-th dimension to `destinations`.
`gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like
@ -2365,31 +2365,30 @@ class StrategyExtendedV2(object):
variable.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
`tf.distribute.experimental.CollectiveHints` for details.
options: a `tf.distribute.experimental.CommunicationOptions`. Options for
performing collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in its constructor. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A tensor or value gathered to `destinations`.
"""
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(destinations, (list, tuple))
if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self._gather_to_implementation(value, destinations, axis, experimental_hints)
if options is None:
options = collective_util.Options()
return self._gather_to_implementation(value, destinations, axis, options)
def _gather_to_implementation(self, value, destinations, axis, experimental_hints):
def _gather_to_implementation(self, value, destinations, axis, options):
raise NotImplementedError("_gather_to must be implemented in descendants")
def _batch_gather_to(self,
value_destination_pairs,
axis,
experimental_hints=None):
def _batch_gather_to(self, value_destination_pairs, axis, options=None):
_require_cross_replica_or_default_context_extended(self)
if experimental_hints is None:
experimental_hints = collective_util.Hints()
if options is None:
options = collective_util.Options()
return [
self._gather_to(
t, destinations=v, axis=axis, experimental_hints=experimental_hints)
self._gather_to(t, destinations=v, axis=axis, options=options)
for t, v in value_destination_pairs
]
@ -2410,7 +2409,8 @@ class StrategyExtendedV2(object):
Example usage:
```python
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2 devices
# With 2 devices
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
with strategy.scope():
v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
def update_fn(v):
@ -2975,7 +2975,7 @@ class ReplicaContext(object):
require_replica_context(self)
return (device_util.current(),)
def all_reduce(self, reduce_op, value, experimental_hints=None):
def all_reduce(self, reduce_op, value, options=None):
"""All-reduces `value` across all replicas.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@ -2988,7 +2988,7 @@ class ReplicaContext(object):
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
It supports batched operations. You can pass a list of values and it
attempts to batch them when possible. You can also specify `experimental_hints`
attempts to batch them when possible. You can also specify `options`
to indicate the desired batching behavior, e.g. batch the values into
multiple packs so that they can better overlap with computations.
@ -3028,8 +3028,11 @@ class ReplicaContext(object):
value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts.
The structure and the shapes of the `tf.Tensor` need to be the same on
all replicas.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
options: a `tf.distribute.experimental.CommunicationOptions`. Options for
performing collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in its constructor. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A nested structure of `tf.Tensor` with the reduced values. The structure
@ -3037,13 +3040,13 @@ class ReplicaContext(object):
"""
if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
if experimental_hints is None:
experimental_hints = collective_util.Hints()
if options is None:
options = collective_util.Options()
def batch_all_reduce(strategy, *value_flat):
return strategy.extended.batch_reduce_to(
reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
experimental_hints)
options)
if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
# TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
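For replica context, the same kind of options object can be threaded through `all_reduce`; a minimal sketch assuming two local GPUs and illustrative option values:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])  # assumed devices
options = tf.distribute.experimental.CommunicationOptions(
    bytes_per_pack=50 * 1024 * 1024)  # illustrative packing size

def replica_fn():
  ctx = tf.distribute.get_replica_context()
  # Passing a list batches the all-reduces; `options` tunes how they are packed.
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                        [tf.constant(1.0), tf.constant(2.0)],
                        options=options)

summed = strategy.run(replica_fn)  # each replica sees [2.0, 4.0]
```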
@ -3070,7 +3073,7 @@ class ReplicaContext(object):
# implemented in terms of `merge_call()` and `batch_reduce_to()`.
# TODO(wxinyi): generate docs after it is implemented for all strategies.
def _all_gather(self, value, axis, experimental_hints=None):
def _all_gather(self, value, axis, options=None):
"""All-gathers `value` across all replicas along `axis`.
Note: An `all_gather` method can only be called in replica context. To find
@ -3147,8 +3150,11 @@ class ReplicaContext(object):
constructs can only be dense tensors with non-zero rank, NOT
`tf.IndexedSlices`.
axis: 0-D int32 Tensor. Dimension along which to gather.
experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
options: a `tf.distribute.experimental.CommunicationOptions`. Options for
performing collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in its constructor. See
`tf.distribute.experimental.CommunicationOptions` for details.
Returns:
A nested structure of `tf.Tensor` with the gathered values. The structure
@ -3159,13 +3165,13 @@ class ReplicaContext(object):
raise NotImplementedError("gather/all_gather does not support "
"IndexedSlices")
if experimental_hints is None:
experimental_hints = collective_util.Hints()
if options is None:
options = collective_util.Options()
def batch_all_gather(strategy, *value_flat):
return strategy.extended._batch_gather_to( # pylint: disable=protected-access
[(v, _batch_reduce_destination(v)) for v in value_flat], axis,
experimental_hints)
options)
@custom_gradient.custom_gradient
def grad_wrapper(*xs):
@ -3319,13 +3325,13 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0):
return fn(*args, **kwargs)
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
# TODO(josh11b): Use destinations?
del reduce_op, destinations, experimental_hints
del reduce_op, destinations, options
return value
def _gather_to_implementation(self, value, destinations, axis, experimental_hints):
del destinations, axis, experimental_hints
def _gather_to_implementation(self, value, destinations, axis, options):
del destinations, axis, options
return value
def _update(self, var, fn, args, kwargs, group):

View File

@ -94,8 +94,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
def _local_results(self, value):
return (value,)
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
del reduce_op, destinations, experimental_hints
def _reduce_to(self, reduce_op, value, destinations, options):
del reduce_op, destinations, options
return value
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import copy
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
@ -313,6 +314,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
assert devices, ("Got an empty `devices` list and unable to recognize "
"any local devices.")
self._cross_device_ops = cross_device_ops
self._communication_options = collective_util.Options()
self._initialize_strategy(devices)
# TODO(b/128995245): Enable last partial batch support in graph mode.
@ -632,8 +634,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
del value # Unused.
return self._cross_device_ops or self._inferred_cross_device_ops
def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
def _gather_to_implementation(self, value, destinations, axis, options):
if not isinstance(value, values.DistributedValues):
# ReductionToOneDevice._gather accepts DistributedValues only.
return value
@ -641,9 +642,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
value,
destinations=destinations,
axis=axis,
experimental_hints=experimental_hints)
options=self._communication_options.merge(options))
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
if (distribute_utils.is_mirrored(value) and
reduce_op == reduce_util.ReduceOp.MEAN):
return value
@ -659,10 +660,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
reduce_op,
value,
destinations=destinations,
experimental_hints=experimental_hints)
options=self._communication_options.merge(options))
def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
cross_device_ops = None
for value, _ in value_destination_pairs:
if cross_device_ops is None:
@ -670,8 +670,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
elif cross_device_ops is not self._get_cross_device_ops(value):
raise ValueError("inputs to batch_reduce_to must be either all on the "
"the host or all on the compute devices")
return cross_device_ops.batch_reduce(reduce_op, value_destination_pairs,
experimental_hints)
return cross_device_ops.batch_reduce(
reduce_op,
value_destination_pairs,
options=self._communication_options.merge(options))
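Every per-call `options` in `MirroredExtended` now goes through `self._communication_options.merge(options)`, combining strategy-level defaults with per-call overrides. The diff does not show `Options.merge` itself; the snippet below is only a plausible reading of that behavior, written as a standalone sketch with plain dicts rather than the real implementation:

```python
# Illustrative only: per-call values win where they are set, otherwise the
# strategy-level defaults are kept. The actual merge lives in
# tensorflow/python/distribute/collective_util.py and may differ in detail.
def merge_options(strategy_defaults, per_call):
  if per_call is None:
    return strategy_defaults
  merged = dict(strategy_defaults)
  for key, value in per_call.items():
    if value is not None:
      merged[key] = value
  return merged

defaults = {"bytes_per_pack": 0, "timeout_seconds": None, "implementation": "AUTO"}
per_call = {"bytes_per_pack": None, "timeout_seconds": 30.0, "implementation": None}
print(merge_options(defaults, per_call))
# {'bytes_per_pack': 0, 'timeout_seconds': 30.0, 'implementation': 'AUTO'}
```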
def _update(self, var, fn, args, kwargs, group):
# TODO(josh11b): In eager mode, use one thread per device.

View File

@ -379,13 +379,12 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
with ops.device(self._device), _OneDeviceReplicaContext(strategy):
return fn(*args, **kwargs)
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
del reduce_op, destinations, experimental_hints
def _reduce_to(self, reduce_op, value, destinations, options):
del reduce_op, destinations, options
return value
def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
del destinations, axis, experimental_hints
def _gather_to_implementation(self, value, destinations, axis, options):
del destinations, axis, options
return value
def _update(self, var, fn, args, kwargs, group):

View File

@ -504,7 +504,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
(d, self._worker_device))
def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
options):
self._verify_destinations_not_different_worker(destinations)
if not isinstance(value, values.DistributedValues):
return value
@ -512,27 +512,22 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
value,
destinations=destinations,
axis=axis,
experimental_hints=experimental_hints)
options=options)
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
self._verify_destinations_not_different_worker(destinations)
if not isinstance(value, values.DistributedValues):
# pylint: disable=protected-access
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, self._num_replicas_in_sync)
return self._cross_device_ops.reduce(
reduce_op,
value,
destinations=destinations,
experimental_hints=experimental_hints)
reduce_op, value, destinations=destinations, options=options)
def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
for _, destinations in value_destination_pairs:
self._verify_destinations_not_different_worker(destinations)
return self._cross_device_ops.batch_reduce(reduce_op,
value_destination_pairs,
experimental_hints)
value_destination_pairs, options)
def _select_single_value(self, structured):
"""Select any single value in `structured`."""

View File

@ -1022,8 +1022,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs)
def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
def _gather_to_implementation(self, value, destinations, axis, options):
if not isinstance(value, values.DistributedValues):
return value
@ -1070,7 +1069,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
return output
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
def _reduce_to(self, reduce_op, value, destinations, options):
if (isinstance(value, values.DistributedValues) or
tensor_util.is_tensor(value)
) and tpu_values.enclosing_tpu_context() is not None:
@ -1412,8 +1411,8 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
return self.strategy.extended.experimental_logical_device(logical_device_id)
# TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it.
def _all_gather(self, value, axis, experimental_hints=None):
del experimental_hints
def _all_gather(self, value, axis, options=None):
del options
for v in nest.flatten(value):
if isinstance(v, ops.IndexedSlices):
raise NotImplementedError("gather/all_gather does not support "

View File

@ -432,7 +432,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
NUM_WORKERS = 3
CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication
CollectiveCommunication = collective_util.CollectiveCommunication
class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
@ -477,8 +477,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=devices,
group_size=len(devices),
collective_keys=collective_keys,
communication=communication)
collective_keys=collective_keys)
return collective_all_reduce_ops, devices, ""
else:
# NCCL requires physical GPUs for every replica, which we can't do with
@ -509,8 +508,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=devices,
group_size=len(devices) * NUM_WORKERS,
collective_keys=collective_keys,
communication=communication)
collective_keys=collective_keys)
return (collective_all_reduce_ops, devices,
"grpc://" + self._cluster_spec[task_type][task_id])

View File

@ -8,11 +8,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
@ -24,10 +24,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -10,11 +10,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
@ -26,10 +26,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -10,11 +10,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
@ -26,10 +26,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -9,11 +9,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
@ -25,10 +25,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -24,7 +24,7 @@ tf_class {
}
member_method {
name: "all_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "merge_call"

View File

@ -37,7 +37,7 @@ tf_class {
}
member_method {
name: "batch_reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "broadcast_to"
@ -69,7 +69,7 @@ tf_class {
}
member_method {
name: "reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "update"

View File

@ -1,16 +1,16 @@
path: "tensorflow.distribute.experimental.CollectiveCommunication"
tf_class {
is_instance: "<enum \'CollectiveCommunication\'>"
is_instance: "<enum \'CommunicationImplemenation\'>"
member {
name: "AUTO"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "NCCL"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "RING"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
}

View File

@ -0,0 +1,16 @@
path: "tensorflow.distribute.experimental.CommunicationImplemenation"
tf_class {
is_instance: "<enum \'CommunicationImplemenation\'>"
member {
name: "AUTO"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "NCCL"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "RING"
mtype: "<enum \'CommunicationImplemenation\'>"
}
}

View File

@ -0,0 +1,9 @@
path: "tensorflow.distribute.experimental.CommunicationOptions"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.collective_util._OptionsExported\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'bytes_per_pack\', \'timeout_seconds\', \'implementation\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'CommunicationImplemenation.AUTO\'], "
}
}
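A construction sketch matching the argspec above; the values are illustrative, and the pre-rename `CollectiveCommunication` name is used for the `implementation` value because the golden file above records it as an instance of the renamed implementation enum:

```python
import tensorflow as tf

# Arguments mirror the recorded argspec; the chosen values are examples only.
options = tf.distribute.experimental.CommunicationOptions(
    bytes_per_pack=4 * 1024 * 1024,   # split large all-reduces into ~4MB packs
    timeout_seconds=60.0,             # error out if a collective hangs past 60s
    implementation=tf.distribute.experimental.CollectiveCommunication.NCCL)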

View File

@ -18,7 +18,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CollectiveCommunication.AUTO\', \'None\'], "
argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CommunicationImplemenation.AUTO\', \'None\'], "
}
member_method {
name: "colocate_vars_with"

View File

@ -12,6 +12,14 @@ tf_module {
name: "CollectiveHints"
mtype: "<type \'type\'>"
}
member {
name: "CommunicationImplemenation"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "CommunicationOptions"
mtype: "<type \'type\'>"
}
member {
name: "MultiWorkerMirroredStrategy"
mtype: "<type \'type\'>"

View File

@ -8,11 +8,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
@ -24,10 +24,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -10,11 +10,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
@ -26,10 +26,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -10,11 +10,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
@ -26,10 +26,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -9,11 +9,11 @@ tf_class {
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
@ -25,10 +25,10 @@ tf_class {
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "reduce_implementation"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -24,7 +24,7 @@ tf_class {
}
member_method {
name: "all_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "merge_call"

View File

@ -20,7 +20,7 @@ tf_class {
}
member_method {
name: "batch_reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "colocate_vars_with"
@ -28,7 +28,7 @@ tf_class {
}
member_method {
name: "reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "update"

View File

@ -1,16 +1,16 @@
path: "tensorflow.distribute.experimental.CollectiveCommunication"
tf_class {
is_instance: "<enum \'CollectiveCommunication\'>"
is_instance: "<enum \'CommunicationImplemenation\'>"
member {
name: "AUTO"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "NCCL"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "RING"
mtype: "<enum \'CollectiveCommunication\'>"
mtype: "<enum \'CommunicationImplemenation\'>"
}
}

View File

@ -0,0 +1,16 @@
path: "tensorflow.distribute.experimental.CommunicationImplemenation"
tf_class {
is_instance: "<enum \'CommunicationImplemenation\'>"
member {
name: "AUTO"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "NCCL"
mtype: "<enum \'CommunicationImplemenation\'>"
}
member {
name: "RING"
mtype: "<enum \'CommunicationImplemenation\'>"
}
}

View File

@ -0,0 +1,9 @@
path: "tensorflow.distribute.experimental.CommunicationOptions"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.collective_util._OptionsExported\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'bytes_per_pack\', \'timeout_seconds\', \'implementation\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'CommunicationImplemenation.AUTO\'], "
}
}

View File

@ -18,7 +18,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CollectiveCommunication.AUTO\', \'None\'], "
argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CommunicationImplemenation.AUTO\', \'None\'], "
}
member_method {
name: "colocate_vars_with"

View File

@ -12,6 +12,14 @@ tf_module {
name: "CollectiveHints"
mtype: "<type \'type\'>"
}
member {
name: "CommunicationImplemenation"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "CommunicationOptions"
mtype: "<type \'type\'>"
}
member {
name: "MultiWorkerMirroredStrategy"
mtype: "<type \'type\'>"